fix `ChatOpenAI.agenerate` (#1504)

fix-searx
Ankush Gola 1 year ago committed by GitHub
parent 4f41e20f09
commit 27104d4921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -46,17 +46,14 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
) -> LLMResult:
"""Top Level call"""
results = []
for m in messages:
results.append(self._generate(m, stop=stop))
results = [self._generate(m, stop=stop) for m in messages]
return LLMResult(generations=[res.generations for res in results])
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
) -> LLMResult:
results = []
for m in messages:
results.append(self._generate(m, stop=stop))
"""Top Level call"""
results = [await self._agenerate(m, stop=stop) for m in messages]
return LLMResult(generations=[res.generations for res in results])
def generate_prompt(

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import sys
from typing import Any, Callable, Dict, List, Mapping, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
return ChatResult(generations=generations)
class ChatOpenAI(BaseChatModel, BaseModel):
"""Wrapper around OpenAI Chat large language models.
@ -215,12 +224,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
message_dicts, params = self._create_message_dicts(messages, stop)
if self.streaming:
inner_completion = ""
role = "assistant"
@ -240,22 +244,23 @@ class ChatOpenAI(BaseChatModel, BaseModel):
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
return ChatResult(generations=generations)
return _create_chat_result(response)
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
if self.streaming:
inner_completion = ""
role = "assistant"
@ -281,15 +286,10 @@ class ChatOpenAI(BaseChatModel, BaseModel):
)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
full_response = await acompletion_with_retry(
response = await acompletion_with_retry(
self, messages=message_dicts, **params
)
generations = []
for res in full_response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
return ChatResult(generations=generations)
return _create_chat_result(response)
@property
def _identifying_params(self) -> Mapping[str, Any]:

@ -87,3 +87,44 @@ def test_chat_openai_invalid_streaming_params() -> None:
temperature=0,
n=5,
)
@pytest.mark.asyncio
async def test_async_chat_openai() -> None:
"""Test async generation."""
chat = ChatOpenAI(max_tokens=10, n=2)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 2
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_async_chat_openai_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

@ -54,7 +54,7 @@ def test_openai_stop_error() -> None:
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an OpenAPI LLM."""
"""Test saving/loading an OpenAI LLM."""
llm = OpenAI(max_tokens=10)
llm.save(file_path=tmp_path / "openai.yaml")
loaded_llm = load_llm(tmp_path / "openai.yaml")

Loading…
Cancel
Save