From 27104d49218b163b59425bb0a92cc7894860f887 Mon Sep 17 00:00:00 2001
From: Ankush Gola <9536492+agola11@users.noreply.github.com>
Date: Tue, 7 Mar 2023 15:22:05 -0800
Subject: [PATCH] fix `ChatOpenAI.agenerate` (#1504)

---
 langchain/chat_models/base.py               |  9 ++--
 langchain/chat_models/openai.py             | 46 +++++++++----------
 .../chat_models/test_openai.py              | 41 +++++++++++++++++
 tests/integration_tests/llms/test_openai.py |  2 +-
 4 files changed, 68 insertions(+), 30 deletions(-)

diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index 91b3744af9..fd23e0409c 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -46,17 +46,14 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
-        results = []
-        for m in messages:
-            results.append(self._generate(m, stop=stop))
+        results = [self._generate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])
 
     async def agenerate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        results = []
-        for m in messages:
-            results.append(self._generate(m, stop=stop))
+        """Top Level call"""
+        results = [await self._agenerate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])
 
     def generate_prompt(
diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index e7910fdb38..5dd1760c44 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import logging
 import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
 
 from pydantic import BaseModel, Extra, Field, root_validator
 from tenacity import (
@@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
 
 
+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+    generations = []
+    for res in response["choices"]:
+        message = _convert_dict_to_message(res["message"])
+        gen = ChatGeneration(message=message)
+        generations.append(gen)
+    return ChatResult(generations=generations)
+
+
 class ChatOpenAI(BaseChatModel, BaseModel):
     """Wrapper around OpenAI Chat large language models.
 
@@ -215,12 +224,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
-        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -240,22 +244,23 @@ class ChatOpenAI(BaseChatModel, BaseModel):
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
-        generations = []
-        for res in response["choices"]:
-            message = _convert_dict_to_message(res["message"])
-            gen = ChatGeneration(message=message)
-            generations.append(gen)
-        return ChatResult(generations=generations)
+        return _create_chat_result(response)
 
-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    async def _agenerate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -281,15 +286,10 @@ class ChatOpenAI(BaseChatModel, BaseModel):
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         else:
-            full_response = await acompletion_with_retry(
+            response = await acompletion_with_retry(
                 self, messages=message_dicts, **params
             )
-            generations = []
-            for res in full_response["choices"]:
-                message = _convert_dict_to_message(res["message"])
-                gen = ChatGeneration(message=message)
-                generations.append(gen)
-            return ChatResult(generations=generations)
+            return _create_chat_result(response)
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py
index ab09288eb5..347c6a76c6 100644
--- a/tests/integration_tests/chat_models/test_openai.py
+++ b/tests/integration_tests/chat_models/test_openai.py
@@ -87,3 +87,44 @@ def test_chat_openai_invalid_streaming_params() -> None:
         temperature=0,
         n=5,
     )
+
+
+@pytest.mark.asyncio
+async def test_async_chat_openai() -> None:
+    """Test async generation."""
+    chat = ChatOpenAI(max_tokens=10, n=2)
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 2
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+
+
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatOpenAI(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index 818068b674..4f80565e1e 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -54,7 +54,7 @@ def test_openai_stop_error() -> None:
 
 
 def test_saving_loading_llm(tmp_path: Path) -> None:
-    """Test saving/loading an OpenAPI LLM."""
+    """Test saving/loading an OpenAI LLM."""
    llm = OpenAI(max_tokens=10)
    llm.save(file_path=tmp_path / "openai.yaml")
    loaded_llm = load_llm(tmp_path / "openai.yaml")
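
Below the patch itself: a minimal sketch of the async path this change fixes, adapted from test_async_chat_openai above. Previously agenerate called the synchronous _generate for each prompt; after the fix it awaits _agenerate. The import paths and the OPENAI_API_KEY requirement are assumptions based on the 2023-era langchain package layout, not something the patch establishes.

# Sketch (not part of the patch): exercising ChatOpenAI.agenerate directly.
# Assumes a March-2023-era langchain install and OPENAI_API_KEY in the
# environment.
import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def main() -> None:
    # n=2 requests two generations per prompt, mirroring the test above.
    chat = ChatOpenAI(max_tokens=10, n=2)
    message = HumanMessage(content="Hello")

    # Two prompts in one batch; with the fix this awaits _agenerate per
    # prompt instead of silently running the blocking _generate.
    result = await chat.agenerate([[message], [message]])

    # result.generations holds one list per prompt, each with n=2 entries.
    for prompt_generations in result.generations:
        for generation in prompt_generations:
            print(generation.message.content)


asyncio.run(main())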
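
Likewise, a sketch of the streaming path covered by test_async_chat_openai_streaming. FakeCallbackHandler is test-only, so this substitutes StreamingStdOutCallbackHandler; that handler and its import path existed in langchain around this time, but they are assumptions here rather than anything this patch touches.

# Sketch (not part of the patch): streaming tokens through a callback
# manager. The test above asserts exactly one generation per prompt when
# streaming=True, so no n parameter is passed here.
import asyncio

from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def stream_demo() -> None:
    chat = ChatOpenAI(
        max_tokens=10,
        streaming=True,  # tokens arrive via on_llm_new_token as they stream
        temperature=0,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
        verbose=True,
    )
    await chat.agenerate([[HumanMessage(content="Hello")]])


asyncio.run(stream_demo())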