fix `ChatOpenAI.agenerate` (#1504)

Ankush Gola committed via GitHub (commit 27104d4921, parent 4f41e20f09)
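Summary: `ChatOpenAI.agenerate` previously called the synchronous `self._generate` instead of awaiting `self._agenerate`, so the async entry point never used the async code path. This commit fixes that, deduplicates the shared request building and response parsing into `_create_message_dicts` and `_create_chat_result` helpers, and adds async integration tests. A minimal usage sketch of the fixed method (imports assume the `langchain` layout at this commit, and a valid `OPENAI_API_KEY` in the environment):

```python
import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def main() -> None:
    chat = ChatOpenAI(max_tokens=10)
    # agenerate takes a list of message lists; after this fix it awaits
    # _agenerate per list instead of calling the synchronous _generate.
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)


asyncio.run(main())
```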

```diff
@@ -46,17 +46,14 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
-        results = []
-        for m in messages:
-            results.append(self._generate(m, stop=stop))
+        results = [self._generate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])

     async def agenerate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        results = []
-        for m in messages:
-            results.append(self._generate(m, stop=stop))
+        """Top Level call"""
+        results = [await self._agenerate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])

     def generate_prompt(
```

```diff
@@ -3,7 +3,7 @@ from __future__ import annotations

 import logging
 import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

 from pydantic import BaseModel, Extra, Field, root_validator
 from tenacity import (
```
```diff
@@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict


+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+    generations = []
+    for res in response["choices"]:
+        message = _convert_dict_to_message(res["message"])
+        gen = ChatGeneration(message=message)
+        generations.append(gen)
+    return ChatResult(generations=generations)
+
+
 class ChatOpenAI(BaseChatModel, BaseModel):
     """Wrapper around OpenAI Chat large language models.
```
```diff
@@ -215,12 +224,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
-        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
```
```diff
@@ -240,22 +244,23 @@ class ChatOpenAI(BaseChatModel, BaseModel):
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
-        generations = []
-        for res in response["choices"]:
-            message = _convert_dict_to_message(res["message"])
-            gen = ChatGeneration(message=message)
-            generations.append(gen)
-        return ChatResult(generations=generations)
+        return _create_chat_result(response)

-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    async def _agenerate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
```
```diff
@@ -281,15 +286,10 @@ class ChatOpenAI(BaseChatModel, BaseModel):
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         else:
-            full_response = await acompletion_with_retry(
+            response = await acompletion_with_retry(
                 self, messages=message_dicts, **params
             )
-            generations = []
-            for res in full_response["choices"]:
-                message = _convert_dict_to_message(res["message"])
-                gen = ChatGeneration(message=message)
-                generations.append(gen)
-            return ChatResult(generations=generations)
+            return _create_chat_result(response)

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
```

```diff
@@ -87,3 +87,44 @@ def test_chat_openai_invalid_streaming_params() -> None:
         temperature=0,
         n=5,
     )
+
+
+@pytest.mark.asyncio
+async def test_async_chat_openai() -> None:
+    """Test async generation."""
+    chat = ChatOpenAI(max_tokens=10, n=2)
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 2
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+
+
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatOpenAI(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
```
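Both new tests hit the live OpenAI API, so they need `OPENAI_API_KEY` set. One plausible way to run just these tests programmatically (the module path is an assumption about the repo layout at this commit; adjust to match):

```python
import pytest

# -k selects both new tests by their shared name prefix.
pytest.main(
    ["-k", "test_async_chat_openai", "tests/integration_tests/chat_models/test_openai.py"]
)
```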

```diff
@@ -54,7 +54,7 @@ def test_openai_stop_error() -> None:

 def test_saving_loading_llm(tmp_path: Path) -> None:
-    """Test saving/loading an OpenAPI LLM."""
+    """Test saving/loading an OpenAI LLM."""
     llm = OpenAI(max_tokens=10)
     llm.save(file_path=tmp_path / "openai.yaml")
     loaded_llm = load_llm(tmp_path / "openai.yaml")
```
