[Partner]: Add metadata to stream response (#22716)

Adds `response_metadata` to streamed responses from OpenAI. This metadata is already returned by `invoke`, but was not populated for `stream`.
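
A minimal usage sketch of the new behavior, mirroring the integration tests below (the model name is illustrative):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")  # illustrative model name

full = None
for chunk in llm.stream("I'm Pickle Rick"):
    # Adding chunks together merges their content and response_metadata.
    full = chunk if full is None else full + chunk

# With this change, the aggregated chunk's response_metadata includes
# fields such as "model_name" and "system_fingerprint" alongside "finish_reason".
print(full.response_metadata)
```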

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Hakan Özdemir authored on 2024-06-17 16:46:50 +03:00, committed by GitHub
parent 42a379c75c
commit c437b1aab7
3 changed files with 56 additions and 23 deletions


@@ -478,7 +478,7 @@ class BaseChatOpenAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
with self.client.create(messages=message_dicts, **params) as response:
for chunk in response:
if not isinstance(chunk, dict):
@@ -490,7 +490,7 @@ class BaseChatOpenAI(BaseChatModel):
output_tokens=token_usage.get("completion_tokens", 0),
total_tokens=token_usage.get("total_tokens", 0),
)
chunk = ChatGenerationChunk(
generation_chunk = ChatGenerationChunk(
message=default_chunk_class(
content="", usage_metadata=usage_metadata
)
@@ -501,24 +501,29 @@ class BaseChatOpenAI(BaseChatModel):
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
chunk = _convert_delta_to_message_chunk(
message_chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
chunk.text, chunk=chunk, logprobs=logprobs
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
yield chunk
yield generation_chunk
def _generate(
self,
@@ -596,7 +601,7 @@ class BaseChatOpenAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
response = await self.async_client.create(messages=message_dicts, **params)
async with response:
async for chunk in response:
@@ -609,7 +614,7 @@ class BaseChatOpenAI(BaseChatModel):
output_tokens=token_usage.get("completion_tokens", 0),
total_tokens=token_usage.get("total_tokens", 0),
)
chunk = ChatGenerationChunk(
generation_chunk = ChatGenerationChunk(
message=default_chunk_class(
content="", usage_metadata=usage_metadata
)
@@ -620,24 +625,31 @@ class BaseChatOpenAI(BaseChatModel):
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
chunk = _convert_delta_to_message_chunk(
message_chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=chunk.text, chunk=chunk, logprobs=logprobs
token=generation_chunk.text,
chunk=generation_chunk,
logprobs=logprobs,
)
yield chunk
yield generation_chunk
async def _agenerate(
self,


@@ -5,7 +5,12 @@ from typing import Any, Optional
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.pydantic_v1 import BaseModel
@@ -170,6 +175,8 @@ def test_openai_streaming(llm: AzureChatOpenAI) -> None:
for chunk in llm.stream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.response_metadata.get("model_name") is not None
@pytest.mark.scheduled
@@ -180,6 +187,8 @@ async def test_openai_astream(llm: AzureChatOpenAI) -> None:
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.response_metadata.get("model_name") is not None
@pytest.mark.scheduled
@@ -217,6 +226,7 @@ async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None
@pytest.mark.scheduled
@@ -225,6 +235,7 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None:
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None
@pytest.mark.skip(reason="Need tool calling model deployed on azure")


@@ -351,20 +351,24 @@ def test_stream() -> None:
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
# check token usage
aggregate: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
assert isinstance(chunk.content, str)
aggregate = chunk if aggregate is None else aggregate + chunk
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
chunks_with_token_counts += 1
if chunks_with_token_counts != 1:
if chunk.response_metadata:
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
assert isinstance(aggregate, AIMessageChunk)
@@ -384,20 +388,24 @@ async def test_astream() -> None:
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
# check token usage
aggregate: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
assert isinstance(chunk.content, str)
aggregate = chunk if aggregate is None else aggregate + chunk
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
chunks_with_token_counts += 1
if chunks_with_token_counts != 1:
if chunk.response_metadata:
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
assert isinstance(aggregate, AIMessageChunk)
@@ -442,6 +450,7 @@ async def test_ainvoke() -> None:
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None
def test_invoke() -> None:
@@ -450,6 +459,7 @@ def test_invoke() -> None:
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None
def test_response_metadata() -> None: