Record system fingerprint chat openai (#12960)

pull/12963/head
Bagatur 11 months ago committed by GitHub
parent 8e0cb2eb84
commit 4f7dff9d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -342,6 +342,7 @@ class ChatOpenAI(BaseChatModel):
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
if output is None:
# Happens in streaming
@ -352,7 +353,12 @@ class ChatOpenAI(BaseChatModel):
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model_name}
if system_fingerprint is None:
system_fingerprint = output.get("system_fingerprint")
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
return combined
def _stream(
self,
@ -430,7 +436,11 @@ class ChatOpenAI(BaseChatModel):
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(

@ -58,6 +58,8 @@ def test_chat_openai_generate() -> None:
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
assert response.llm_output
assert "system_fingerprint" in response.llm_output
for generations in response.generations:
assert len(generations) == 2
for generation in generations:
@ -163,6 +165,8 @@ async def test_async_chat_openai() -> None:
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
assert response.llm_output
assert "system_fingerprint" in response.llm_output
for generations in response.generations:
assert len(generations) == 2
for generation in generations:

Loading…
Cancel
Save