diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index cf502aa6f1..2df447b70f 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -495,6 +495,7 @@ class BaseChatOpenAI(BaseChatModel):
                             content="", usage_metadata=usage_metadata
                         )
                     )
+                    logprobs = None
                 else:
                     continue
             else:
@@ -619,6 +620,7 @@ class BaseChatOpenAI(BaseChatModel):
                             content="", usage_metadata=usage_metadata
                         )
                     )
+                    logprobs = None
                 else:
                     continue
             else:
@@ -1386,11 +1388,11 @@ class ChatOpenAI(BaseChatOpenAI):
 
             {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
 
-        When streaming, set the ``stream_options`` model kwarg:
+        When streaming, set the ``stream_usage`` kwarg:
 
         .. code-block:: python
 
-            stream = llm.stream(messages, stream_options={"include_usage": True})
+            stream = llm.stream(messages, stream_usage=True)
             full = next(stream)
             for chunk in stream:
                 full += chunk
@@ -1400,7 +1402,7 @@ class ChatOpenAI(BaseChatOpenAI):
 
             {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
 
-        Alternatively, setting ``stream_options`` when instantiating the model can be
+        Alternatively, setting ``stream_usage`` when instantiating the model can be
         useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using
         methods like ``.with_structured_output``, which generate chains under the hood.
 
@@ -1409,7 +1411,7 @@ class ChatOpenAI(BaseChatOpenAI):
 
             llm = ChatOpenAI(
                 model="gpt-4o",
-                model_kwargs={"stream_options": {"include_usage": True}},
+                stream_usage=True,
             )
 
             structured_llm = llm.with_structured_output(...)
@@ -1446,6 +1448,11 @@ class ChatOpenAI(BaseChatOpenAI):
 
     """  # noqa: E501
 
+    stream_usage: bool = False
+    """Whether to include usage metadata in streaming output. If True, additional
+    message chunks will be generated during the stream including usage metadata.
+    """
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"openai_api_key": "OPENAI_API_KEY"}
@@ -1475,6 +1482,44 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
+    def _should_stream_usage(
+        self, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> bool:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of preference
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
+    def _stream(
+        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> Iterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        kwargs["stream_options"] = {"include_usage": stream_usage}
+
+        return super()._stream(*args, **kwargs)
+
+    async def _astream(
+        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        kwargs["stream_options"] = {"include_usage": stream_usage}
+
+        async for chunk in super()._astream(*args, **kwargs):
+            yield chunk
+
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 10273ea3c8..0fb5bf1ce9 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1,5 +1,5 @@
 """Test ChatOpenAI chat model."""
-from typing import Any, List, Optional, cast
+from typing import Any, AsyncIterator, List, Optional, cast
 
 import pytest
 from langchain_core.callbacks import CallbackManager
@@ -357,7 +357,7 @@ def test_stream() -> None:
     aggregate: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
     chunks_with_response_metadata = 0
-    for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
+    for chunk in llm.stream("Hello", stream_usage=True):
         assert isinstance(chunk.content, str)
         aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
@@ -380,39 +380,73 @@
 
 
 async def test_astream() -> None:
     """Test streaming tokens from OpenAI."""
-    llm = ChatOpenAI()
-
-    full: Optional[BaseMessageChunk] = None
-    async for chunk in llm.astream("I'm Pickle Rick"):
-        assert isinstance(chunk.content, str)
-        full = chunk if full is None else full + chunk
-    assert isinstance(full, AIMessageChunk)
-    assert full.response_metadata.get("finish_reason") is not None
-    assert full.response_metadata.get("model_name") is not None
-    # check token usage
-    aggregate: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
-    chunks_with_response_metadata = 0
-    async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
-        assert isinstance(chunk, AIMessageChunk)
-        if chunk.usage_metadata is not None:
-            chunks_with_token_counts += 1
-        if chunk.response_metadata:
-            chunks_with_response_metadata += 1
-    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
-        raise AssertionError(
-            "Expected exactly one chunk with metadata. "
-            "AIMessageChunk aggregation can add these metadata. Check that "
-            "this is behaving properly."
-        )
-    assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is not None
-    assert aggregate.usage_metadata["input_tokens"] > 0
-    assert aggregate.usage_metadata["output_tokens"] > 0
-    assert aggregate.usage_metadata["total_tokens"] > 0
+    async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
+        full: Optional[BaseMessageChunk] = None
+        chunks_with_token_counts = 0
+        chunks_with_response_metadata = 0
+        async for chunk in stream:
+            assert isinstance(chunk.content, str)
+            full = chunk if full is None else full + chunk
+            assert isinstance(chunk, AIMessageChunk)
+            if chunk.usage_metadata is not None:
+                chunks_with_token_counts += 1
+            if chunk.response_metadata:
+                chunks_with_response_metadata += 1
+        assert isinstance(full, AIMessageChunk)
+        if chunks_with_response_metadata != 1:
+            raise AssertionError(
+                "Expected exactly one chunk with metadata. "
+                "AIMessageChunk aggregation can add these metadata. Check that "
+                "this is behaving properly."
+            )
+        assert full.response_metadata.get("finish_reason") is not None
+        assert full.response_metadata.get("model_name") is not None
+        if expect_usage:
+            if chunks_with_token_counts != 1:
+                raise AssertionError(
+                    "Expected exactly one chunk with token counts. "
+                    "AIMessageChunk aggregation adds counts. Check that "
+                    "this is behaving properly."
+                )
+            assert full.usage_metadata is not None
+            assert full.usage_metadata["input_tokens"] > 0
+            assert full.usage_metadata["output_tokens"] > 0
+            assert full.usage_metadata["total_tokens"] > 0
+        else:
+            assert chunks_with_token_counts == 0
+            assert full.usage_metadata is None
+
+    llm = ChatOpenAI(temperature=0, max_tokens=5)
+    await _test_stream(llm.astream("Hello"), expect_usage=False)
+    await _test_stream(
+        llm.astream("Hello", stream_options={"include_usage": True}),
+        expect_usage=True,
+    )
+    await _test_stream(
+        llm.astream("Hello", stream_usage=True),
+        expect_usage=True,
+    )
+    llm = ChatOpenAI(
+        temperature=0,
+        max_tokens=5,
+        model_kwargs={"stream_options": {"include_usage": True}},
+    )
+    await _test_stream(llm.astream("Hello"), expect_usage=True)
+    await _test_stream(
+        llm.astream("Hello", stream_options={"include_usage": False}),
+        expect_usage=False,
+    )
+    llm = ChatOpenAI(
+        temperature=0,
+        max_tokens=5,
+        stream_usage=True,
+    )
+    await _test_stream(llm.astream("Hello"), expect_usage=True)
+    await _test_stream(
+        llm.astream("Hello", stream_usage=False),
+        expect_usage=False,
+    )
 
 
 async def test_abatch() -> None: