openai[patch]: add stream_usage parameter (#22854)

Here we add `stream_usage` to ChatOpenAI as:

1. a boolean attribute
2. a kwarg to _stream and _astream.

Question: should the `stream_usage` attribute be `bool`, or `bool |
None`?

Currently I've kept it `bool` and defaulted to False. It was implemented
on
[ChatAnthropic](e832bbb486/libs/partners/anthropic/langchain_anthropic/chat_models.py (L535))
as a bool. However, to maintain support for users who access the
behavior via OpenAI's `stream_options` param, this ends up being
possible:
```python
llm = ChatOpenAI(model_kwargs={"stream_options": {"include_usage": True}})
assert not llm.stream_usage
```
(and this model will stream token usage).

Some options for this:
- it's ok
- make the `stream_usage` attribute bool or None
- make an \_\_init\_\_ for ChatOpenAI, set a `._stream_usage` attribute
and read `.stream_usage` from a property

Open to other ideas as well.
This commit is contained in:
ccurme 2024-06-17 13:35:18 -04:00 committed by GitHub
parent 56ac94e014
commit 722c8f50ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 116 additions and 37 deletions

View File

@ -495,6 +495,7 @@ class BaseChatOpenAI(BaseChatModel):
content="", usage_metadata=usage_metadata
)
)
logprobs = None
else:
continue
else:
@ -619,6 +620,7 @@ class BaseChatOpenAI(BaseChatModel):
content="", usage_metadata=usage_metadata
)
)
logprobs = None
else:
continue
else:
@ -1386,11 +1388,11 @@ class ChatOpenAI(BaseChatOpenAI):
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
When streaming, set the ``stream_options`` model kwarg:
When streaming, set the ``stream_usage`` kwarg:
.. code-block:: python
stream = llm.stream(messages, stream_options={"include_usage": True})
stream = llm.stream(messages, stream_usage=True)
full = next(stream)
for chunk in stream:
full += chunk
@ -1400,7 +1402,7 @@ class ChatOpenAI(BaseChatOpenAI):
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
Alternatively, setting ``stream_options`` when instantiating the model can be
Alternatively, setting ``stream_usage`` when instantiating the model can be
useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using
methods like ``.with_structured_output``, which generate chains under the
hood.
@ -1409,7 +1411,7 @@ class ChatOpenAI(BaseChatOpenAI):
llm = ChatOpenAI(
model="gpt-4o",
model_kwargs={"stream_options": {"include_usage": True}},
stream_usage=True,
)
structured_llm = llm.with_structured_output(...)
@ -1446,6 +1448,11 @@ class ChatOpenAI(BaseChatOpenAI):
""" # noqa: E501
stream_usage: bool = False
"""Whether to include usage metadata in streaming output. If True, additional
message chunks will be generated during the stream including usage metadata.
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@ -1475,6 +1482,44 @@ class ChatOpenAI(BaseChatOpenAI):
"""Return whether this model can be serialized by Langchain."""
return True
def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
) -> bool:
"""Determine whether to include usage metadata in streaming output.
For backwards compatibility, we check for `stream_options` passed
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
"""
stream_usage_sources = [ # order of preference
stream_usage,
kwargs.get("stream_options", {}).get("include_usage"),
self.model_kwargs.get("stream_options", {}).get("include_usage"),
self.stream_usage,
]
for source in stream_usage_sources:
if isinstance(source, bool):
return source
return self.stream_usage
def _stream(
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> Iterator[ChatGenerationChunk]:
"""Set default stream_options."""
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
kwargs["stream_options"] = {"include_usage": stream_usage}
return super()._stream(*args, **kwargs)
async def _astream(
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Set default stream_options."""
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
kwargs["stream_options"] = {"include_usage": stream_usage}
async for chunk in super()._astream(*args, **kwargs):
yield chunk
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)

View File

@ -1,5 +1,5 @@
"""Test ChatOpenAI chat model."""
from typing import Any, List, Optional, cast
from typing import Any, AsyncIterator, List, Optional, cast
import pytest
from langchain_core.callbacks import CallbackManager
@ -357,7 +357,7 @@ def test_stream() -> None:
aggregate: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
for chunk in llm.stream("Hello", stream_usage=True):
assert isinstance(chunk.content, str)
aggregate = chunk if aggregate is None else aggregate + chunk
assert isinstance(chunk, AIMessageChunk)
@ -380,39 +380,73 @@ def test_stream() -> None:
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI()
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for chunk in stream:
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
chunks_with_token_counts += 1
if chunk.response_metadata:
chunks_with_response_metadata += 1
assert isinstance(full, AIMessageChunk)
if chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
if expect_usage:
if chunks_with_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
else:
assert chunks_with_token_counts == 0
assert full.usage_metadata is None
# check token usage
aggregate: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
assert isinstance(chunk.content, str)
aggregate = chunk if aggregate is None else aggregate + chunk
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
chunks_with_token_counts += 1
if chunk.response_metadata:
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
assert isinstance(aggregate, AIMessageChunk)
assert aggregate.usage_metadata is not None
assert aggregate.usage_metadata["input_tokens"] > 0
assert aggregate.usage_metadata["output_tokens"] > 0
assert aggregate.usage_metadata["total_tokens"] > 0
llm = ChatOpenAI(temperature=0, max_tokens=5)
await _test_stream(llm.astream("Hello"), expect_usage=False)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": True}),
expect_usage=True,
)
await _test_stream(
llm.astream("Hello", stream_usage=True),
expect_usage=True,
)
llm = ChatOpenAI(
temperature=0,
max_tokens=5,
model_kwargs={"stream_options": {"include_usage": True}},
)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": False}),
expect_usage=False,
)
llm = ChatOpenAI(
temperature=0,
max_tokens=5,
stream_usage=True,
)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(
llm.astream("Hello", stream_usage=False),
expect_usage=False,
)
async def test_abatch() -> None: