core[minor]: allow LLMs async streaming to fallback on sync streaming (#18960)

- **Description:** Allows `astream` to fall back on the synchronous `_stream` implementation (wrapped as an async iterator) for LLMs that don't implement `_astream`; if neither streaming method is implemented, it falls back to `ainvoke` as before.
- **Issue:** #18920 
- **Twitter handle:** @maximeperrin_
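
A rough usage sketch of what this change enables (the `SyncOnlyLLM` class and this script are illustrative, not part of the PR, and mirror the new unit tests below): an LLM that only implements the synchronous `_stream` hook can now be consumed through `astream` and still stream chunk by chunk.

```python
import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult


class SyncOnlyLLM(BaseLLM):
    """Toy model: implements the sync _stream hook but not _astream."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        raise NotImplementedError

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        yield GenerationChunk(text="hello ")
        yield GenerationChunk(text="world")

    @property
    def _llm_type(self) -> str:
        return "sync-only-fake"


async def main() -> None:
    # With this change, astream wraps the sync _stream in an executor and
    # yields "hello " then "world"; before, astream ignored _stream and fell
    # back to ainvoke (which goes through _generate).
    async for token in SyncOnlyLLM().astream("hi"):
        print(token, end="")


asyncio.run(main())
```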

---------

Co-authored-by: Maxime Perrin <mperrin@doing.fr>
Authored by Maxime Perrin on 2024-03-15 21:06:50 +01:00, committed by GitHub
parent caf47ab666, commit aa785fa6ec
2 changed files with 188 additions and 46 deletions

Changed file 1 of 2: langchain_core/language_models/llms.py

@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -113,6 +114,26 @@ def create_base_retry_decorator(
     )


+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
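
For context, here is a standalone sketch of the same sync-to-async bridging pattern that `_as_async_iterator` uses, written against plain `asyncio` rather than `run_in_executor` from `langchain_core` (the `as_async_iterator` and `demo` names are made up for the example): the iterator is created and advanced in a worker thread so it never blocks the event loop, and a sentinel object signals exhaustion.

```python
import asyncio
from typing import Any, AsyncIterator, Callable, Iterator


def as_async_iterator(
    sync_iterator: Callable[..., Iterator[Any]]
) -> Callable[..., AsyncIterator[Any]]:
    """Wrap a callable returning a sync iterator so it can be used with `async for`."""

    async def wrapped(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        loop = asyncio.get_running_loop()
        # Create the iterator in a worker thread (mirrors run_in_executor in llms.py).
        iterator = await loop.run_in_executor(
            None, lambda: sync_iterator(*args, **kwargs)
        )
        done = object()  # sentinel returned by next() once the iterator is exhausted
        while True:
            # Each next() call also runs in a worker thread, so a slow or blocking
            # sync iterator never stalls the event loop.
            item = await loop.run_in_executor(None, next, iterator, done)
            if item is done:
                break
            yield item

    return wrapped


async def demo() -> None:
    numbers = as_async_iterator(lambda: iter([1, 2, 3]))
    async for n in numbers():
        print(n)  # 1, 2, 3


asyncio.run(demo())
```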
@@ -434,10 +455,29 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream == BaseLLM._astream:
+        if type(self)._astream is not BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseLLM._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        str,
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[GenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
+            )
+        else:
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            return
         prompt = self._convert_input(input).to_string()
         config = ensure_config(config)
         params = self.dict()
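
The dispatch above hinges on detecting whether a subclass actually overrode `_astream` or `_stream`. A tiny sketch of that override-detection idiom, with made-up class names:

```python
class Base:
    def _stream(self):  # default sync hook on the base class
        raise NotImplementedError

    async def _astream(self):  # default async hook on the base class
        raise NotImplementedError


class SyncOnly(Base):
    def _stream(self):  # only the sync hook is overridden
        yield "a"


m = SyncOnly()
# False: the subclass did not override _astream, so there is no native async streaming.
print(type(m)._astream is not Base._astream)
# True: the sync hook was overridden, so astream can wrap it via _as_async_iterator.
print(type(m)._stream is not Base._stream)
```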
@@ -463,7 +503,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         )
         generation: Optional[GenerationChunk] = None
         try:
-            async for chunk in self._astream(
+            async for chunk in _stream_implementation(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
             ):
                 yield chunk.text
@@ -475,9 +515,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         except BaseException as e:
             await run_manager.on_llm_error(
                 e,
-                response=LLMResult(
-                    generations=[[generation]] if generation else []
-                ),
+                response=LLMResult(generations=[[generation]] if generation else []),
             )
             raise e
         else:

Changed file 2 of 2: unit tests for BaseLLM streaming

@@ -1,6 +1,13 @@
 from typing import Any, AsyncIterator, Iterator, List, Optional

 import pytest

-from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ async def test_stream_error_callback() -> None:
                 pass
         eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            generations = [Generation(text="hello")]
+            return LLMResult(generations=[generations])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["hello"]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["hello"]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["a", "b"]
+    assert type(model)._astream == BaseLLM._astream
+
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == ["a", "b"]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["a", "b"]