mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
openai[patch]: add stream_usage parameter (#22854)
Here we add `stream_usage` to ChatOpenAI as:
1. a boolean attribute
2. a kwarg to _stream and _astream.
Question: should the `stream_usage` attribute be `bool`, or `bool |
None`?
Currently I've kept it `bool` and defaulted to False. It was implemented
on
[ChatAnthropic](e832bbb486/libs/partners/anthropic/langchain_anthropic/chat_models.py (L535)
)
as a bool. However, to maintain support for users who access the
behavior via OpenAI's `stream_options` param, this ends up being
possible:
```python
llm = ChatOpenAI(model_kwargs={"stream_options": {"include_usage": True}})
assert not llm.stream_usage
```
(and this model will stream token usage).
Some options for this:
- it's ok
- make the `stream_usage` attribute bool or None
- make an \_\_init\_\_ for ChatOpenAI, set a `._stream_usage` attribute
and read `.stream_usage` from a property
Open to other ideas as well.
This commit is contained in:
parent
56ac94e014
commit
722c8f50ea
@ -495,6 +495,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
content="", usage_metadata=usage_metadata
|
||||
)
|
||||
)
|
||||
logprobs = None
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
@ -619,6 +620,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
content="", usage_metadata=usage_metadata
|
||||
)
|
||||
)
|
||||
logprobs = None
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
@ -1386,11 +1388,11 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
|
||||
|
||||
When streaming, set the ``stream_options`` model kwarg:
|
||||
When streaming, set the ``stream_usage`` kwarg:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
stream = llm.stream(messages, stream_options={"include_usage": True})
|
||||
stream = llm.stream(messages, stream_usage=True)
|
||||
full = next(stream)
|
||||
for chunk in stream:
|
||||
full += chunk
|
||||
@ -1400,7 +1402,7 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
|
||||
|
||||
Alternatively, setting ``stream_options`` when instantiating the model can be
|
||||
Alternatively, setting ``stream_usage`` when instantiating the model can be
|
||||
useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using
|
||||
methods like ``.with_structured_output``, which generate chains under the
|
||||
hood.
|
||||
@ -1409,7 +1411,7 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o",
|
||||
model_kwargs={"stream_options": {"include_usage": True}},
|
||||
stream_usage=True,
|
||||
)
|
||||
structured_llm = llm.with_structured_output(...)
|
||||
|
||||
@ -1446,6 +1448,11 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
stream_usage: bool = False
|
||||
"""Whether to include usage metadata in streaming output. If True, additional
|
||||
message chunks will be generated during the stream including usage metadata.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
@ -1475,6 +1482,44 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
def _should_stream_usage(
|
||||
self, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
) -> bool:
|
||||
"""Determine whether to include usage metadata in streaming output.
|
||||
|
||||
For backwards compatibility, we check for `stream_options` passed
|
||||
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
|
||||
"""
|
||||
stream_usage_sources = [ # order of preference
|
||||
stream_usage,
|
||||
kwargs.get("stream_options", {}).get("include_usage"),
|
||||
self.model_kwargs.get("stream_options", {}).get("include_usage"),
|
||||
self.stream_usage,
|
||||
]
|
||||
for source in stream_usage_sources:
|
||||
if isinstance(source, bool):
|
||||
return source
|
||||
return self.stream_usage
|
||||
|
||||
def _stream(
|
||||
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Set default stream_options."""
|
||||
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
|
||||
return super()._stream(*args, **kwargs)
|
||||
|
||||
async def _astream(
|
||||
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
"""Set default stream_options."""
|
||||
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
|
||||
async for chunk in super()._astream(*args, **kwargs):
|
||||
yield chunk
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Test ChatOpenAI chat model."""
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, AsyncIterator, List, Optional, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
@ -357,7 +357,7 @@ def test_stream() -> None:
|
||||
aggregate: Optional[BaseMessageChunk] = None
|
||||
chunks_with_token_counts = 0
|
||||
chunks_with_response_metadata = 0
|
||||
for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
|
||||
for chunk in llm.stream("Hello", stream_usage=True):
|
||||
assert isinstance(chunk.content, str)
|
||||
aggregate = chunk if aggregate is None else aggregate + chunk
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
@ -380,39 +380,73 @@ def test_stream() -> None:
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
async for chunk in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(chunk.content, str)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.response_metadata.get("finish_reason") is not None
|
||||
assert full.response_metadata.get("model_name") is not None
|
||||
async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
chunks_with_token_counts = 0
|
||||
chunks_with_response_metadata = 0
|
||||
async for chunk in stream:
|
||||
assert isinstance(chunk.content, str)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
if chunk.usage_metadata is not None:
|
||||
chunks_with_token_counts += 1
|
||||
if chunk.response_metadata:
|
||||
chunks_with_response_metadata += 1
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
if chunks_with_response_metadata != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one chunk with metadata. "
|
||||
"AIMessageChunk aggregation can add these metadata. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
assert full.response_metadata.get("finish_reason") is not None
|
||||
assert full.response_metadata.get("model_name") is not None
|
||||
if expect_usage:
|
||||
if chunks_with_token_counts != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one chunk with token counts. "
|
||||
"AIMessageChunk aggregation adds counts. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
assert full.usage_metadata is not None
|
||||
assert full.usage_metadata["input_tokens"] > 0
|
||||
assert full.usage_metadata["output_tokens"] > 0
|
||||
assert full.usage_metadata["total_tokens"] > 0
|
||||
else:
|
||||
assert chunks_with_token_counts == 0
|
||||
assert full.usage_metadata is None
|
||||
|
||||
# check token usage
|
||||
aggregate: Optional[BaseMessageChunk] = None
|
||||
chunks_with_token_counts = 0
|
||||
chunks_with_response_metadata = 0
|
||||
async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
|
||||
assert isinstance(chunk.content, str)
|
||||
aggregate = chunk if aggregate is None else aggregate + chunk
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
if chunk.usage_metadata is not None:
|
||||
chunks_with_token_counts += 1
|
||||
if chunk.response_metadata:
|
||||
chunks_with_response_metadata += 1
|
||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||
raise AssertionError(
|
||||
"Expected exactly one chunk with metadata. "
|
||||
"AIMessageChunk aggregation can add these metadata. Check that "
|
||||
"this is behaving properly."
|
||||
)
|
||||
assert isinstance(aggregate, AIMessageChunk)
|
||||
assert aggregate.usage_metadata is not None
|
||||
assert aggregate.usage_metadata["input_tokens"] > 0
|
||||
assert aggregate.usage_metadata["output_tokens"] > 0
|
||||
assert aggregate.usage_metadata["total_tokens"] > 0
|
||||
llm = ChatOpenAI(temperature=0, max_tokens=5)
|
||||
await _test_stream(llm.astream("Hello"), expect_usage=False)
|
||||
await _test_stream(
|
||||
llm.astream("Hello", stream_options={"include_usage": True}),
|
||||
expect_usage=True,
|
||||
)
|
||||
await _test_stream(
|
||||
llm.astream("Hello", stream_usage=True),
|
||||
expect_usage=True,
|
||||
)
|
||||
llm = ChatOpenAI(
|
||||
temperature=0,
|
||||
max_tokens=5,
|
||||
model_kwargs={"stream_options": {"include_usage": True}},
|
||||
)
|
||||
await _test_stream(llm.astream("Hello"), expect_usage=True)
|
||||
await _test_stream(
|
||||
llm.astream("Hello", stream_options={"include_usage": False}),
|
||||
expect_usage=False,
|
||||
)
|
||||
llm = ChatOpenAI(
|
||||
temperature=0,
|
||||
max_tokens=5,
|
||||
stream_usage=True,
|
||||
)
|
||||
await _test_stream(llm.astream("Hello"), expect_usage=True)
|
||||
await _test_stream(
|
||||
llm.astream("Hello", stream_usage=False),
|
||||
expect_usage=False,
|
||||
)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user