@ -1,5 +1,5 @@
""" Test ChatOpenAI chat model. """
from typing import Any , List, Optional , cast
from typing import Any , AsyncIterator, List, Optional , cast
import pytest
from langchain_core . callbacks import CallbackManager
@ -357,7 +357,7 @@ def test_stream() -> None:
aggregate : Optional [ BaseMessageChunk ] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
for chunk in llm . stream ( " Hello " , stream_ options= { " include_usage " : True } ) :
for chunk in llm . stream ( " Hello " , stream_ usage= True ) :
assert isinstance ( chunk . content , str )
aggregate = chunk if aggregate is None else aggregate + chunk
assert isinstance ( chunk , AIMessageChunk )
@ -380,39 +380,73 @@ def test_stream() -> None:
async def test_astream ( ) - > None :
""" Test streaming tokens from OpenAI. """
llm = ChatOpenAI ( )
full : Optional [ BaseMessageChunk ] = None
async for chunk in llm . astream ( " I ' m Pickle Rick " ) :
assert isinstance ( chunk . content , str )
full = chunk if full is None else full + chunk
assert isinstance ( full , AIMessageChunk )
assert full . response_metadata . get ( " finish_reason " ) is not None
assert full . response_metadata . get ( " model_name " ) is not None
# check token usage
aggregate : Optional [ BaseMessageChunk ] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for chunk in llm . astream ( " Hello " , stream_options = { " include_usage " : True } ) :
assert isinstance ( chunk . content , str )
aggregate = chunk if aggregate is None else aggregate + chunk
assert isinstance ( chunk , AIMessageChunk )
if chunk . usage_metadata is not None :
chunks_with_token_counts + = 1
if chunk . response_metadata :
chunks_with_response_metadata + = 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1 :
raise AssertionError (
" Expected exactly one chunk with metadata. "
" AIMessageChunk aggregation can add these metadata. Check that "
" this is behaving properly. "
)
assert isinstance ( aggregate , AIMessageChunk )
assert aggregate . usage_metadata is not None
assert aggregate . usage_metadata [ " input_tokens " ] > 0
assert aggregate . usage_metadata [ " output_tokens " ] > 0
assert aggregate . usage_metadata [ " total_tokens " ] > 0
async def _test_stream ( stream : AsyncIterator , expect_usage : bool ) - > None :
full : Optional [ BaseMessageChunk ] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for chunk in stream :
assert isinstance ( chunk . content , str )
full = chunk if full is None else full + chunk
assert isinstance ( chunk , AIMessageChunk )
if chunk . usage_metadata is not None :
chunks_with_token_counts + = 1
if chunk . response_metadata :
chunks_with_response_metadata + = 1
assert isinstance ( full , AIMessageChunk )
if chunks_with_response_metadata != 1 :
raise AssertionError (
" Expected exactly one chunk with metadata. "
" AIMessageChunk aggregation can add these metadata. Check that "
" this is behaving properly. "
)
assert full . response_metadata . get ( " finish_reason " ) is not None
assert full . response_metadata . get ( " model_name " ) is not None
if expect_usage :
if chunks_with_token_counts != 1 :
raise AssertionError (
" Expected exactly one chunk with token counts. "
" AIMessageChunk aggregation adds counts. Check that "
" this is behaving properly. "
)
assert full . usage_metadata is not None
assert full . usage_metadata [ " input_tokens " ] > 0
assert full . usage_metadata [ " output_tokens " ] > 0
assert full . usage_metadata [ " total_tokens " ] > 0
else :
assert chunks_with_token_counts == 0
assert full . usage_metadata is None
llm = ChatOpenAI ( temperature = 0 , max_tokens = 5 )
await _test_stream ( llm . astream ( " Hello " ) , expect_usage = False )
await _test_stream (
llm . astream ( " Hello " , stream_options = { " include_usage " : True } ) ,
expect_usage = True ,
)
await _test_stream (
llm . astream ( " Hello " , stream_usage = True ) ,
expect_usage = True ,
)
llm = ChatOpenAI (
temperature = 0 ,
max_tokens = 5 ,
model_kwargs = { " stream_options " : { " include_usage " : True } } ,
)
await _test_stream ( llm . astream ( " Hello " ) , expect_usage = True )
await _test_stream (
llm . astream ( " Hello " , stream_options = { " include_usage " : False } ) ,
expect_usage = False ,
)
llm = ChatOpenAI (
temperature = 0 ,
max_tokens = 5 ,
stream_usage = True ,
)
await _test_stream ( llm . astream ( " Hello " ) , expect_usage = True )
await _test_stream (
llm . astream ( " Hello " , stream_usage = False ) ,
expect_usage = False ,
)
async def test_abatch ( ) - > None :