(ChatOpenAI) Add model_name to LLMResult.llm_output (#1960)
This makes sure OpenAI and ChatOpenAI return the same llm_output shape, and allows tracking usage per model. The same change was made for OpenAI in https://github.com/hwchase17/langchain/pull/1713.
parent 6e0d3880df
commit e7d6de6b1c
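For context, a minimal usage sketch of the new field (illustrative only, not part of the commit; assumes an OPENAI_API_KEY is set in the environment and roughly this version of langchain):

from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(max_tokens=10)
llm_result = chat.generate([[HumanMessage(content="Hello")]])

# llm_output now carries the model name alongside token usage,
# matching what the OpenAI (non-chat) wrapper already returns.
print(llm_result.llm_output["model_name"])   # e.g. "gpt-3.5-turbo"
print(llm_result.llm_output["token_usage"])  # e.g. {"prompt_tokens": ..., "total_tokens": ...}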
langchain/chat_models/openai.py

@@ -91,16 +91,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
 
 
-def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
-    generations = []
-    for res in response["choices"]:
-        message = _convert_dict_to_message(res["message"])
-        gen = ChatGeneration(message=message)
-        generations.append(gen)
-    llm_output = {"token_usage": response["usage"]}
-    return ChatResult(generations=generations, llm_output=llm_output)
-
-
 class ChatOpenAI(BaseChatModel, BaseModel):
     """Wrapper around OpenAI Chat large language models.
 
@@ -237,7 +227,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
                     overall_token_usage[k] += v
                 else:
                     overall_token_usage[k] = v
-        return {"token_usage": overall_token_usage}
+        return {"token_usage": overall_token_usage, "model_name": self.model_name}
 
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
@@ -262,7 +252,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
-       return _create_chat_result(response)
+       return self._create_chat_result(response)
 
    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
@@ -275,6 +265,15 @@ class ChatOpenAI(BaseChatModel, BaseModel):
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params
 
+   def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+       generations = []
+       for res in response["choices"]:
+           message = _convert_dict_to_message(res["message"])
+           gen = ChatGeneration(message=message)
+           generations.append(gen)
+       llm_output = {"token_usage": response["usage"], "model_name": self.model_name}
+       return ChatResult(generations=generations, llm_output=llm_output)
+
    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
@@ -307,7 +306,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
            response = await acompletion_with_retry(
                self, messages=message_dicts, **params
            )
-           return _create_chat_result(response)
+           return self._create_chat_result(response)
 
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
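Downstream, a per-model usage tally becomes possible because every llm_output now names its model. A rough sketch under that assumption (the helper below is hypothetical, not part of langchain or this commit):

from typing import Dict, List

from langchain.schema import LLMResult

def usage_by_model(results: List[LLMResult]) -> Dict[str, dict]:
    # Hypothetical helper: sum token_usage counters keyed by model_name.
    totals: Dict[str, dict] = {}
    for result in results:
        output = result.llm_output or {}
        model = output.get("model_name", "unknown")
        bucket = totals.setdefault(model, {})
        for key, value in output.get("token_usage", {}).items():
            bucket[key] = bucket.get(key, 0) + value
    return totals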
tests/integration_tests/chat_models/test_openai.py

@@ -1,5 +1,6 @@
 """Test ChatOpenAI wrapper."""
 
+
 import pytest
 
 from langchain.callbacks.base import CallbackManager
@@ -78,6 +79,24 @@ def test_chat_openai_streaming() -> None:
    assert isinstance(response, BaseMessage)
 
 
+def test_chat_openai_llm_output_contains_model_name() -> None:
+   """Test llm_output contains model_name."""
+   chat = ChatOpenAI(max_tokens=10)
+   message = HumanMessage(content="Hello")
+   llm_result = chat.generate([[message]])
+   assert llm_result.llm_output is not None
+   assert llm_result.llm_output["model_name"] == chat.model_name
+
+
+def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
+   """Test llm_output contains model_name."""
+   chat = ChatOpenAI(max_tokens=10, streaming=True)
+   message = HumanMessage(content="Hello")
+   llm_result = chat.generate([[message]])
+   assert llm_result.llm_output is not None
+   assert llm_result.llm_output["model_name"] == chat.model_name
+
+
 def test_chat_openai_invalid_streaming_params() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    with pytest.raises(ValueError):