forked from Archives/langchain
fix ChatOpenAI.agenerate
(#1504)
This commit is contained in:
parent
4f41e20f09
commit
27104d4921
@ -46,17 +46,14 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
results = []
|
||||
for m in messages:
|
||||
results.append(self._generate(m, stop=stop))
|
||||
results = [self._generate(m, stop=stop) for m in messages]
|
||||
return LLMResult(generations=[res.generations for res in results])
|
||||
|
||||
async def agenerate(
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
results = []
|
||||
for m in messages:
|
||||
results.append(self._generate(m, stop=stop))
|
||||
"""Top Level call"""
|
||||
results = [await self._agenerate(m, stop=stop) for m in messages]
|
||||
return LLMResult(generations=[res.generations for res in results])
|
||||
|
||||
def generate_prompt(
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from tenacity import (
|
||||
@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
return message_dict
|
||||
|
||||
|
||||
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
return ChatResult(generations=generations)
|
||||
|
||||
|
||||
class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
"""Wrapper around OpenAI Chat large language models.
|
||||
|
||||
@ -215,12 +224,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> ChatResult:
|
||||
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
if self.streaming:
|
||||
inner_completion = ""
|
||||
role = "assistant"
|
||||
@ -240,22 +244,23 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
response = self.completion_with_retry(messages=message_dicts, **params)
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
return ChatResult(generations=generations)
|
||||
return _create_chat_result(response)
|
||||
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> ChatResult:
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
if self.streaming:
|
||||
inner_completion = ""
|
||||
role = "assistant"
|
||||
@ -281,15 +286,10 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
else:
|
||||
full_response = await acompletion_with_retry(
|
||||
response = await acompletion_with_retry(
|
||||
self, messages=message_dicts, **params
|
||||
)
|
||||
generations = []
|
||||
for res in full_response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
return ChatResult(generations=generations)
|
||||
return _create_chat_result(response)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
|
@ -87,3 +87,44 @@ def test_chat_openai_invalid_streaming_params() -> None:
|
||||
temperature=0,
|
||||
n=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_openai() -> None:
|
||||
"""Test async generation."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
@ -54,7 +54,7 @@ def test_openai_stop_error() -> None:
|
||||
|
||||
|
||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||
"""Test saving/loading an OpenAPI LLM."""
|
||||
"""Test saving/loading an OpenAI LLM."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
llm.save(file_path=tmp_path / "openai.yaml")
|
||||
loaded_llm = load_llm(tmp_path / "openai.yaml")
|
||||
|
Loading…
Reference in New Issue
Block a user