diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index 24d19d7a..900e89ff 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -91,16 +91,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
 
 
-def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
-    generations = []
-    for res in response["choices"]:
-        message = _convert_dict_to_message(res["message"])
-        gen = ChatGeneration(message=message)
-        generations.append(gen)
-    llm_output = {"token_usage": response["usage"]}
-    return ChatResult(generations=generations, llm_output=llm_output)
-
-
 class ChatOpenAI(BaseChatModel, BaseModel):
     """Wrapper around OpenAI Chat large language models.
 
@@ -237,7 +227,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
                     overall_token_usage[k] += v
                 else:
                     overall_token_usage[k] = v
-        return {"token_usage": overall_token_usage}
+        return {"token_usage": overall_token_usage, "model_name": self.model_name}
 
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
@@ -262,7 +252,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
                 )
             return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
-        return _create_chat_result(response)
+        return self._create_chat_result(response)
 
     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
@@ -275,6 +265,15 @@ class ChatOpenAI(BaseChatModel, BaseModel):
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
 
+    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+        generations = []
+        for res in response["choices"]:
+            message = _convert_dict_to_message(res["message"])
+            gen = ChatGeneration(message=message)
+            generations.append(gen)
+        llm_output = {"token_usage": response["usage"], "model_name": self.model_name}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
     async def _agenerate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
@@ -307,7 +306,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
             response = await acompletion_with_retry(
                 self, messages=message_dicts, **params
             )
-            return _create_chat_result(response)
+            return self._create_chat_result(response)
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py
index 347c6a76..06394ebc 100644
--- a/tests/integration_tests/chat_models/test_openai.py
+++ b/tests/integration_tests/chat_models/test_openai.py
@@ -1,5 +1,6 @@
 """Test ChatOpenAI wrapper."""
 
+
 import pytest
 
 from langchain.callbacks.base import CallbackManager
@@ -78,6 +79,24 @@ def test_chat_openai_streaming() -> None:
     assert isinstance(response, BaseMessage)
 
 
+def test_chat_openai_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    chat = ChatOpenAI(max_tokens=10)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == chat.model_name
+
+
+def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    chat = ChatOpenAI(max_tokens=10, streaming=True)
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == chat.model_name
+
+
 def test_chat_openai_invalid_streaming_params() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     with pytest.raises(ValueError):