Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
[Partner]: Add metadata to stream response (#22716)
Adds `response_metadata` to stream responses from OpenAI. This is normally returned by `invoke`, but was not implemented for `stream`.

Co-authored-by: Chester Curme <chester.curme@gmail.com>
parent 42a379c75c
commit c437b1aab7
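As a rough usage sketch (not part of the diff; assumes `langchain-openai` is installed, an OpenAI API key is configured, and the model name is only illustrative), the aggregated stream now carries the same metadata that `invoke` returns:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

full = None
for chunk in llm.stream("I'm Pickle Rick"):
    # Each chunk is an AIMessageChunk; adding chunks merges their metadata.
    full = chunk if full is None else full + chunk

# Before this change, these keys were only populated on invoke() results.
print(full.response_metadata.get("model_name"))
print(full.response_metadata.get("finish_reason"))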
libs/partners/openai/langchain_openai/chat_models/base.py

@@ -478,7 +478,7 @@ class BaseChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         with self.client.create(messages=message_dicts, **params) as response:
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -490,7 +490,7 @@ class BaseChatOpenAI(BaseChatModel):
                             output_tokens=token_usage.get("completion_tokens", 0),
                             total_tokens=token_usage.get("total_tokens", 0),
                         )
-                        chunk = ChatGenerationChunk(
+                        generation_chunk = ChatGenerationChunk(
                             message=default_chunk_class(
                                 content="", usage_metadata=usage_metadata
                             )
@@ -501,24 +501,29 @@ class BaseChatOpenAI(BaseChatModel):
                     choice = chunk["choices"][0]
                     if choice["delta"] is None:
                         continue
-                    chunk = _convert_delta_to_message_chunk(
+                    message_chunk = _convert_delta_to_message_chunk(
                         choice["delta"], default_chunk_class
                     )
                     generation_info = {}
                     if finish_reason := choice.get("finish_reason"):
                         generation_info["finish_reason"] = finish_reason
+                        if model_name := chunk.get("model"):
+                            generation_info["model_name"] = model_name
+                        if system_fingerprint := chunk.get("system_fingerprint"):
+                            generation_info["system_fingerprint"] = system_fingerprint
+
                     logprobs = choice.get("logprobs")
                     if logprobs:
                         generation_info["logprobs"] = logprobs
-                    default_chunk_class = chunk.__class__
-                    chunk = ChatGenerationChunk(
-                        message=chunk, generation_info=generation_info or None
+                    default_chunk_class = message_chunk.__class__
+                    generation_chunk = ChatGenerationChunk(
+                        message=message_chunk, generation_info=generation_info or None
                     )
                 if run_manager:
                     run_manager.on_llm_new_token(
-                        chunk.text, chunk=chunk, logprobs=logprobs
+                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
-                yield chunk
+                yield generation_chunk
 
     def _generate(
         self,
@@ -596,7 +601,7 @@ class BaseChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         response = await self.async_client.create(messages=message_dicts, **params)
         async with response:
             async for chunk in response:
@@ -609,7 +614,7 @@ class BaseChatOpenAI(BaseChatModel):
                             output_tokens=token_usage.get("completion_tokens", 0),
                             total_tokens=token_usage.get("total_tokens", 0),
                         )
-                        chunk = ChatGenerationChunk(
+                        generation_chunk = ChatGenerationChunk(
                             message=default_chunk_class(
                                 content="", usage_metadata=usage_metadata
                             )
@@ -620,24 +625,31 @@ class BaseChatOpenAI(BaseChatModel):
                     choice = chunk["choices"][0]
                     if choice["delta"] is None:
                         continue
-                    chunk = _convert_delta_to_message_chunk(
+                    message_chunk = _convert_delta_to_message_chunk(
                         choice["delta"], default_chunk_class
                     )
                     generation_info = {}
                     if finish_reason := choice.get("finish_reason"):
                         generation_info["finish_reason"] = finish_reason
+                        if model_name := chunk.get("model"):
+                            generation_info["model_name"] = model_name
+                        if system_fingerprint := chunk.get("system_fingerprint"):
+                            generation_info["system_fingerprint"] = system_fingerprint
+
                     logprobs = choice.get("logprobs")
                     if logprobs:
                         generation_info["logprobs"] = logprobs
-                    default_chunk_class = chunk.__class__
-                    chunk = ChatGenerationChunk(
-                        message=chunk, generation_info=generation_info or None
+                    default_chunk_class = message_chunk.__class__
+                    generation_chunk = ChatGenerationChunk(
+                        message=message_chunk, generation_info=generation_info or None
                     )
                 if run_manager:
                     await run_manager.on_llm_new_token(
-                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                        token=generation_chunk.text,
+                        chunk=generation_chunk,
+                        logprobs=logprobs,
                     )
-                yield chunk
+                yield generation_chunk
 
     async def _agenerate(
         self,
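For context (not part of the diff): the `generation_info` attached to the final `ChatGenerationChunk` is what surfaces as `response_metadata` on the streamed message, and adding `AIMessageChunk`s together merges those dicts, which is why the tests below can read `model_name` off the aggregated chunk. A minimal sketch using `langchain_core` message chunks, with made-up values:

from langchain_core.messages import AIMessageChunk

# Adding message chunks merges their response_metadata dicts, so metadata
# attached only to the last streamed chunk (finish_reason, model_name,
# system_fingerprint) ends up on the aggregate.
first = AIMessageChunk(content="Hello ")
last = AIMessageChunk(
    content="world",
    response_metadata={"finish_reason": "stop", "model_name": "gpt-3.5-turbo-0125"},
)
merged = first + last
print(merged.content)            # "Hello world"
print(merged.response_metadata)  # contains finish_reason and model_name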
libs/partners/openai/tests/integration_tests/chat_models/test_azure.py

@@ -5,7 +5,12 @@ from typing import Any, Optional
 
 import pytest
 from langchain_core.callbacks import CallbackManager
-from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
+from langchain_core.messages import (
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
 from langchain_core.pydantic_v1 import BaseModel
 
@@ -170,6 +175,8 @@ def test_openai_streaming(llm: AzureChatOpenAI) -> None:
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("model_name") is not None
 
 
 @pytest.mark.scheduled
@@ -180,6 +187,8 @@ async def test_openai_astream(llm: AzureChatOpenAI) -> None:
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("model_name") is not None
 
 
 @pytest.mark.scheduled
@@ -217,6 +226,7 @@ async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
+    assert result.response_metadata.get("model_name") is not None
 
 
 @pytest.mark.scheduled
@@ -225,6 +235,7 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None:
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+    assert result.response_metadata.get("model_name") is not None
 
 
 @pytest.mark.skip(reason="Need tool calling model deployed on azure")
libs/partners/openai/tests/integration_tests/chat_models/test_base.py

@@ -351,20 +351,24 @@ def test_stream() -> None:
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert full.response_metadata.get("finish_reason") is not None
+    assert full.response_metadata.get("model_name") is not None
 
     # check token usage
     aggregate: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
+    chunks_with_response_metadata = 0
     for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
         assert isinstance(chunk.content, str)
         aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
-    if chunks_with_token_counts != 1:
+        if chunk.response_metadata:
+            chunks_with_response_metadata += 1
+    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
         raise AssertionError(
-            "Expected exactly one chunk with token counts. "
-            "AIMessageChunk aggregation adds counts. Check that "
+            "Expected exactly one chunk with metadata. "
+            "AIMessageChunk aggregation can add these metadata. Check that "
             "this is behaving properly."
         )
     assert isinstance(aggregate, AIMessageChunk)
@@ -384,20 +388,24 @@ async def test_astream() -> None:
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert full.response_metadata.get("finish_reason") is not None
+    assert full.response_metadata.get("model_name") is not None
 
     # check token usage
     aggregate: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
+    chunks_with_response_metadata = 0
     async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
         assert isinstance(chunk.content, str)
         aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
-    if chunks_with_token_counts != 1:
+        if chunk.response_metadata:
+            chunks_with_response_metadata += 1
+    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
         raise AssertionError(
-            "Expected exactly one chunk with token counts. "
-            "AIMessageChunk aggregation adds counts. Check that "
+            "Expected exactly one chunk with metadata. "
+            "AIMessageChunk aggregation can add these metadata. Check that "
             "this is behaving properly."
         )
     assert isinstance(aggregate, AIMessageChunk)
@@ -442,6 +450,7 @@ async def test_ainvoke() -> None:
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
+    assert result.response_metadata.get("model_name") is not None
 
 
 def test_invoke() -> None:
@@ -450,6 +459,7 @@ def test_invoke() -> None:
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+    assert result.response_metadata.get("model_name") is not None
 
 
 def test_response_metadata() -> None: