@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any ,
AsyncGenerator ,
AsyncIterator ,
Callable ,
Dict ,
@ -113,6 +114,26 @@ def create_base_retry_decorator(
)
def _as_async_iterator(sync_iterator: Callable) -> Callable:
    """Convert a sync-iterator factory into an async-iterator factory.

    Args:
        sync_iterator: A callable that, when invoked, returns a synchronous
            iterator (e.g. a generator function such as ``self._stream``).

    Returns:
        A callable with the same call signature that returns an async
        generator yielding the same items. Every potentially blocking step
        (creating the iterator and each ``next()`` call) is dispatched to an
        executor via ``run_in_executor`` so the event loop is never blocked.
    """

    # NOTE: renamed from the original `_as_sync_iterator` — the wrapper is an
    # *async* generator, and the old name read backwards.
    async def _as_async_gen(*args: Any, **kwargs: Any) -> AsyncGenerator:
        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
        # Unique sentinel object: passing it as next()'s default lets the
        # worker thread signal exhaustion without a StopIteration exception
        # having to propagate across the executor boundary.
        done = object()
        while True:
            item = await run_in_executor(
                None,
                next,
                iterator,
                done,  # type: ignore[call-arg, arg-type]
            )
            if item is done:
                break
            yield item  # type: ignore[misc]

    return _as_async_gen
def get_prompts (
params : Dict [ str , Any ] , prompts : List [ str ]
) - > Tuple [ Dict [ int , List ] , str , List [ int ] , List [ str ] ] :
@ -434,54 +455,71 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop : Optional [ List [ str ] ] = None ,
* * kwargs : Any ,
) - > AsyncIterator [ str ] :
if type ( self ) . _astream == BaseLLM . _astream :
if type ( self ) . _astream is not BaseLLM . _astream :
# model doesn't implement streaming, so use default implementation
yield await self . ainvoke ( input , config = config , stop = stop , * * kwargs )
el se:
prompt = self . _convert_input ( input ) . to_string ( )
config = ensure_config ( config )
params = self . dict ( )
params[ " stop " ] = stop
params = { * * params , * * kwargs }
options = { " stop " : stop }
callback_manager = AsyncCallbackManager . configure (
config . get ( " callbacks " ) ,
self . callbacks ,
self . verbose ,
config . get ( " tags " ) ,
self . tags ,
config . get ( " metadata " ) ,
self . metadata ,
_stream_implementation = self . _astream
el if type ( self) . _stream is not BaseLLM . _stream :
# Then stream is implemented, so we can create an async iterator from it
# The typing is hard to type correctly with mypy here, so we cast
# and do a type ignore, this code is unit tested and should be fine.
_stream_implementation = cast ( # type: ignore
Callable [
[
str ,
Optional [ List [ str ] ] ,
CallbackManagerForLLMRun ,
Any ,
] ,
AsyncIterator [ GenerationChunk ] ,
] ,
_as_async_iterator ( self . _stream ) ,
)
( run_manager , ) = await callback_manager . on_llm_start (
dumpd ( self ) ,
[ prompt ] ,
invocation_params = params ,
options = options ,
name = config . get ( " run_name " ) ,
batch_size = 1 ,
else :
yield await self . ainvoke ( input , config = config , stop = stop , * * kwargs )
return
prompt = self . _convert_input ( input ) . to_string ( )
config = ensure_config ( config )
params = self . dict ( )
params [ " stop " ] = stop
params = { * * params , * * kwargs }
options = { " stop " : stop }
callback_manager = AsyncCallbackManager . configure (
config . get ( " callbacks " ) ,
self . callbacks ,
self . verbose ,
config . get ( " tags " ) ,
self . tags ,
config . get ( " metadata " ) ,
self . metadata ,
)
( run_manager , ) = await callback_manager . on_llm_start (
dumpd ( self ) ,
[ prompt ] ,
invocation_params = params ,
options = options ,
name = config . get ( " run_name " ) ,
batch_size = 1 ,
)
generation : Optional [ GenerationChunk ] = None
try :
async for chunk in _stream_implementation (
prompt , stop = stop , run_manager = run_manager , * * kwargs
) :
yield chunk . text
if generation is None :
generation = chunk
else :
generation + = chunk
assert generation is not None
except BaseException as e :
await run_manager . on_llm_error (
e ,
response = LLMResult ( generations = [ [ generation ] ] if generation else [ ] ) ,
)
generation : Optional [ GenerationChunk ] = None
try :
async for chunk in self . _astream (
prompt , stop = stop , run_manager = run_manager , * * kwargs
) :
yield chunk . text
if generation is None :
generation = chunk
else :
generation + = chunk
assert generation is not None
except BaseException as e :
await run_manager . on_llm_error (
e ,
response = LLMResult (
generations = [ [ generation ] ] if generation else [ ]
) ,
)
raise e
else :
await run_manager . on_llm_end ( LLMResult ( generations = [ [ generation ] ] ) )
raise e
else :
await run_manager . on_llm_end ( LLMResult ( generations = [ [ generation ] ] ) )
# --- Custom methods ---