@@ -1,6 +1,13 @@
from typing import Any, AsyncIterator, Iterator, List, Optional

import pytest

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
    BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ async def test_stream_error_callback() -> None:
                pass
            eval_response(cb_sync, i)
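

# ModelWithGenerate implements neither _stream nor _astream, so stream() and
# astream() should fall back to invoke()/ainvoke() and yield the whole
# completion as a single chunk -- the fallback behavior the asserts below verify.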
async def test_astream_fallback_to_ainvoke() -> None:
    """Test astream falls back to ainvoke when no streaming is implemented."""

    class ModelWithGenerate(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            generations = [Generation(text="hello")]
            return LLMResult(generations=[generations])

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithGenerate()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == ["hello"]

    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == ["hello"]
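

# ModelWithSyncStream overrides only the sync _stream. Since _astream is not
# overridden, astream() should detect this and drive the sync generator instead,
# producing the same chunks from both entry points.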
async def test_astream_implementation_fallback_to_stream() -> None:
    """Test astream falls back to the sync _stream implementation."""

    class ModelWithSyncStream(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            """Top Level call."""
            raise NotImplementedError()

        def _stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> Iterator[GenerationChunk]:
            """Stream the output of the model."""
            yield GenerationChunk(text="a")
            yield GenerationChunk(text="b")

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithSyncStream()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == ["a", "b"]

    # _astream is inherited unchanged from BaseLLM, so astream() must be
    # reusing the sync _stream defined above.
    assert type(model)._astream == BaseLLM._astream

    astream_chunks = [chunk async for chunk in model.astream("anything")]
    assert astream_chunks == ["a", "b"]
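

# ModelWithAsyncStream provides a native _astream, so astream() should use it
# directly rather than falling back to _stream or ainvoke.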
async def test_astream_implementation_uses_astream() -> None:
    """Test astream uses a provided _astream implementation."""

    class ModelWithAsyncStream(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            """Top Level call."""
            raise NotImplementedError()

        async def _astream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> AsyncIterator[GenerationChunk]:
            """Stream the output of the model."""
            yield GenerationChunk(text="a")
            yield GenerationChunk(text="b")

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithAsyncStream()
    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == ["a", "b"]