(OpenAI) Add model_name to LLMResult.llm_output (#1713)

Different models have very different latencies and prices, so it is
beneficial to pass along information about the model that generated a
response. That information makes it possible to implement custom
callback managers that track usage and cost per model.

Addresses https://github.com/hwchase17/langchain/issues/1557.
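
For illustration only, this is the kind of per-model accounting the extra
field enables. A minimal sketch; the helper name and the per-1K-token
prices are made-up placeholders, not part of this change:

from collections import defaultdict
from typing import Dict

# Illustrative placeholder prices per 1K total tokens -- not real pricing.
PRICE_PER_1K: Dict[str, float] = {"text-davinci-003": 0.02}

tokens_by_model: Dict[str, int] = defaultdict(int)
cost_by_model: Dict[str, float] = defaultdict(float)

def record_usage(llm_output: dict) -> None:
    """Accumulate token usage and estimated cost, keyed by model_name."""
    model = llm_output.get("model_name", "unknown")
    used = llm_output.get("token_usage", {}).get("total_tokens", 0)
    tokens_by_model[model] += used
    cost_by_model[model] += used / 1000 * PRICE_PER_1K.get(model, 0.0)

The same logic could equally live in a custom callback handler's
on_llm_end hook, which receives the full LLMResult.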
Mario Kostelac 2023-03-17 05:55:55 +01:00 committed by GitHub
parent 8a95fdaee1
commit aff44d0a98
2 changed files with 20 additions and 5 deletions


@@ -362,9 +362,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
                     for choice in sub_choices
                 ]
             )
-        return LLMResult(
-            generations=generations, llm_output={"token_usage": token_usage}
-        )
+        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+        return LLMResult(generations=generations, llm_output=llm_output)

     def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
         """Call OpenAI with streaming flag and return the resulting generator.
@@ -643,11 +642,15 @@ class OpenAIChat(BaseLLM, BaseModel):
             )
         else:
             full_response = completion_with_retry(self, messages=messages, **params)
+            llm_output = {
+                "token_usage": full_response["usage"],
+                "model_name": self.model_name,
+            }
             return LLMResult(
                 generations=[
                     [Generation(text=full_response["choices"][0]["message"]["content"])]
                 ],
-                llm_output={"token_usage": full_response["usage"]},
+                llm_output=llm_output,
             )

     async def _agenerate(
@@ -679,11 +682,15 @@ class OpenAIChat(BaseLLM, BaseModel):
             full_response = await acompletion_with_retry(
                 self, messages=messages, **params
             )
+            llm_output = {
+                "token_usage": full_response["usage"],
+                "model_name": self.model_name,
+            }
             return LLMResult(
                 generations=[
                     [Generation(text=full_response["choices"][0]["message"]["content"])]
                 ],
-                llm_output={"token_usage": full_response["usage"]},
+                llm_output=llm_output,
             )

     @property
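
With the changes above, a plain generate call returns both fields
together. A quick sketch (assumes an OPENAI_API_KEY is set in the
environment):

from langchain.llms import OpenAI

llm = OpenAI(max_tokens=10)
result = llm.generate(["Hello, how are you?"])
# e.g. {"token_usage": {...}, "model_name": "text-davinci-003"}
print(result.llm_output)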


@@ -35,6 +35,14 @@ def test_openai_extra_kwargs() -> None:
         OpenAI(foo=3, model_kwargs={"foo": 2})


+def test_openai_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    llm = OpenAI(max_tokens=10)
+    llm_result = llm.generate(["Hello, how are you?"])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == llm.model_name
+
+
 def test_openai_stop_valid() -> None:
     """Test openai stop logic on valid configuration."""
     query = "write an ordered list of five items"
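
The new test above hits the live API. For an offline check of the same
contract, an LLMResult can be built by hand; a sketch, assuming
Generation and LLMResult are importable from langchain.schema:

from langchain.schema import Generation, LLMResult

result = LLMResult(
    generations=[[Generation(text="hi")]],
    llm_output={"token_usage": {"total_tokens": 3}, "model_name": "text-davinci-003"},
)
assert result.llm_output["model_name"] == "text-davinci-003"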