diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index d4f37f5398..cf502aa6f1 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -478,7 +478,7 @@ class BaseChatOpenAI(BaseChatModel): message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} - default_chunk_class = AIMessageChunk + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk with self.client.create(messages=message_dicts, **params) as response: for chunk in response: if not isinstance(chunk, dict): @@ -490,7 +490,7 @@ class BaseChatOpenAI(BaseChatModel): output_tokens=token_usage.get("completion_tokens", 0), total_tokens=token_usage.get("total_tokens", 0), ) - chunk = ChatGenerationChunk( + generation_chunk = ChatGenerationChunk( message=default_chunk_class( content="", usage_metadata=usage_metadata ) @@ -501,24 +501,29 @@ class BaseChatOpenAI(BaseChatModel): choice = chunk["choices"][0] if choice["delta"] is None: continue - chunk = _convert_delta_to_message_chunk( + message_chunk = _convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) generation_info = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs - default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk( - message=chunk, generation_info=generation_info or None + default_chunk_class = message_chunk.__class__ + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None ) if run_manager: run_manager.on_llm_new_token( - chunk.text, chunk=chunk, logprobs=logprobs + generation_chunk.text, chunk=generation_chunk, logprobs=logprobs ) - yield chunk + yield generation_chunk def _generate( self, @@ -596,7 +601,7 @@ class BaseChatOpenAI(BaseChatModel): message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} - default_chunk_class = AIMessageChunk + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk response = await self.async_client.create(messages=message_dicts, **params) async with response: async for chunk in response: @@ -609,7 +614,7 @@ class BaseChatOpenAI(BaseChatModel): output_tokens=token_usage.get("completion_tokens", 0), total_tokens=token_usage.get("total_tokens", 0), ) - chunk = ChatGenerationChunk( + generation_chunk = ChatGenerationChunk( message=default_chunk_class( content="", usage_metadata=usage_metadata ) @@ -620,24 +625,31 @@ class BaseChatOpenAI(BaseChatModel): choice = chunk["choices"][0] if choice["delta"] is None: continue - chunk = _convert_delta_to_message_chunk( + message_chunk = _convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) generation_info = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs - default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk( - message=chunk, generation_info=generation_info or None + default_chunk_class = message_chunk.__class__ + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None ) if run_manager: await run_manager.on_llm_new_token( - token=chunk.text, chunk=chunk, logprobs=logprobs + token=generation_chunk.text, + chunk=generation_chunk, + logprobs=logprobs, ) - yield chunk + yield generation_chunk async def _agenerate( self, diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py index 2cd97fd0cb..04ef044a9d 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py @@ -5,7 +5,12 @@ from typing import Any, Optional import pytest from langchain_core.callbacks import CallbackManager -from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + HumanMessage, +) from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.pydantic_v1 import BaseModel @@ -170,6 +175,8 @@ def test_openai_streaming(llm: AzureChatOpenAI) -> None: for chunk in llm.stream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.response_metadata.get("model_name") is not None @pytest.mark.scheduled @@ -180,6 +187,8 @@ async def test_openai_astream(llm: AzureChatOpenAI) -> None: async for chunk in llm.astream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.response_metadata.get("model_name") is not None @pytest.mark.scheduled @@ -217,6 +226,7 @@ async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None: result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result.content, str) + assert result.response_metadata.get("model_name") is not None @pytest.mark.scheduled @@ -225,6 +235,7 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None: result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) + assert result.response_metadata.get("model_name") is not None @pytest.mark.skip(reason="Need tool calling model deployed on azure") diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 5a31bc76f9..10273ea3c8 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -351,20 +351,24 @@ def test_stream() -> None: full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) assert full.response_metadata.get("finish_reason") is not None + assert full.response_metadata.get("model_name") is not None # check token usage aggregate: Optional[BaseMessageChunk] = None chunks_with_token_counts = 0 + chunks_with_response_metadata = 0 for chunk in llm.stream("Hello", stream_options={"include_usage": True}): assert isinstance(chunk.content, str) aggregate = chunk if aggregate is None else aggregate + chunk assert isinstance(chunk, AIMessageChunk) if chunk.usage_metadata is not None: chunks_with_token_counts += 1 - if chunks_with_token_counts != 1: + if chunk.response_metadata: + chunks_with_response_metadata += 1 + if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: raise AssertionError( - "Expected exactly one chunk with token counts. " - "AIMessageChunk aggregation adds counts. Check that " + "Expected exactly one chunk with metadata. " + "AIMessageChunk aggregation can add these metadata. Check that " "this is behaving properly." ) assert isinstance(aggregate, AIMessageChunk) @@ -384,20 +388,24 @@ async def test_astream() -> None: full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) assert full.response_metadata.get("finish_reason") is not None + assert full.response_metadata.get("model_name") is not None # check token usage aggregate: Optional[BaseMessageChunk] = None chunks_with_token_counts = 0 + chunks_with_response_metadata = 0 async for chunk in llm.astream("Hello", stream_options={"include_usage": True}): assert isinstance(chunk.content, str) aggregate = chunk if aggregate is None else aggregate + chunk assert isinstance(chunk, AIMessageChunk) if chunk.usage_metadata is not None: chunks_with_token_counts += 1 - if chunks_with_token_counts != 1: + if chunk.response_metadata: + chunks_with_response_metadata += 1 + if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: raise AssertionError( - "Expected exactly one chunk with token counts. " - "AIMessageChunk aggregation adds counts. Check that " + "Expected exactly one chunk with metadata. " + "AIMessageChunk aggregation can add these metadata. Check that " "this is behaving properly." ) assert isinstance(aggregate, AIMessageChunk) @@ -442,6 +450,7 @@ async def test_ainvoke() -> None: result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result.content, str) + assert result.response_metadata.get("model_name") is not None def test_invoke() -> None: @@ -450,6 +459,7 @@ def test_invoke() -> None: result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) + assert result.response_metadata.get("model_name") is not None def test_response_metadata() -> None: