(ChatOpenAI) Add model_name to LLMResult.llm_output (#1960)

This makes sure OpenAI and ChatOpenAI have the same llm_output, and
allow tracking usage per model. Same work for OpenAI was done in
https://github.com/hwchase17/langchain/pull/1713.
This commit is contained in:
Mario Kostelac 2023-03-24 16:51:16 +01:00 committed by GitHub
parent 6e0d3880df
commit e7d6de6b1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 13 deletions

View File

@ -91,16 +91,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict return message_dict
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
llm_output = {"token_usage": response["usage"]}
return ChatResult(generations=generations, llm_output=llm_output)
class ChatOpenAI(BaseChatModel, BaseModel): class ChatOpenAI(BaseChatModel, BaseModel):
"""Wrapper around OpenAI Chat large language models. """Wrapper around OpenAI Chat large language models.
@ -237,7 +227,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
overall_token_usage[k] += v overall_token_usage[k] += v
else: else:
overall_token_usage[k] = v overall_token_usage[k] = v
return {"token_usage": overall_token_usage} return {"token_usage": overall_token_usage, "model_name": self.model_name}
def _generate( def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None self, messages: List[BaseMessage], stop: Optional[List[str]] = None
@ -262,7 +252,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
) )
return ChatResult(generations=[ChatGeneration(message=message)]) return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params) response = self.completion_with_retry(messages=message_dicts, **params)
return _create_chat_result(response) return self._create_chat_result(response)
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: List[BaseMessage], stop: Optional[List[str]]
@ -275,6 +265,15 @@ class ChatOpenAI(BaseChatModel, BaseModel):
message_dicts = [_convert_message_to_dict(m) for m in messages] message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
llm_output = {"token_usage": response["usage"], "model_name": self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)
async def _agenerate( async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult: ) -> ChatResult:
@ -307,7 +306,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
response = await acompletion_with_retry( response = await acompletion_with_retry(
self, messages=message_dicts, **params self, messages=message_dicts, **params
) )
return _create_chat_result(response) return self._create_chat_result(response)
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:

View File

@ -1,5 +1,6 @@
"""Test ChatOpenAI wrapper.""" """Test ChatOpenAI wrapper."""
import pytest import pytest
from langchain.callbacks.base import CallbackManager from langchain.callbacks.base import CallbackManager
@ -78,6 +79,24 @@ def test_chat_openai_streaming() -> None:
assert isinstance(response, BaseMessage) assert isinstance(response, BaseMessage)
def test_chat_openai_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
chat = ChatOpenAI(max_tokens=10)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
assert llm_result.llm_output["model_name"] == chat.model_name
def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
chat = ChatOpenAI(max_tokens=10, streaming=True)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
assert llm_result.llm_output["model_name"] == chat.model_name
def test_chat_openai_invalid_streaming_params() -> None: def test_chat_openai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback.""" """Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError): with pytest.raises(ValueError):