diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index b6111b5459..f2fd712028 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -113,6 +114,26 @@ def create_base_retry_decorator(
     )
 
 
+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
@@ -434,54 +455,71 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream == BaseLLM._astream:
-            # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+        if type(self)._astream is not BaseLLM._astream:
+            # model implements native async streaming, so use it directly
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseLLM._stream:
+            # Sync streaming is implemented, so create an async iterator from it.
+            # The typing is hard to get right with mypy here, so we cast and add a
+            # type ignore; this code path is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        str,
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[GenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
+            )
         else:
-            prompt = self._convert_input(input).to_string()
-            config = ensure_config(config)
-            params = self.dict()
-            params["stop"] = stop
-            params = {**params, **kwargs}
-            options = {"stop": stop}
-            callback_manager = AsyncCallbackManager.configure(
-                config.get("callbacks"),
-                self.callbacks,
-                self.verbose,
-                config.get("tags"),
-                self.tags,
-                config.get("metadata"),
-                self.metadata,
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            return
+
+        prompt = self._convert_input(input).to_string()
+        config = ensure_config(config)
+        params = self.dict()
+        params["stop"] = stop
+        params = {**params, **kwargs}
+        options = {"stop": stop}
+        callback_manager = AsyncCallbackManager.configure(
+            config.get("callbacks"),
+            self.callbacks,
+            self.verbose,
+            config.get("tags"),
+            self.tags,
+            config.get("metadata"),
+            self.metadata,
+        )
+        (run_manager,) = await callback_manager.on_llm_start(
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
+            batch_size=1,
+        )
+        generation: Optional[GenerationChunk] = None
+        try:
+            async for chunk in _stream_implementation(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                yield chunk.text
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+        except BaseException as e:
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=[[generation]] if generation else []),
             )
-            (run_manager,) = await callback_manager.on_llm_start(
-                dumpd(self),
-                [prompt],
-                invocation_params=params,
-                options=options,
-                name=config.get("run_name"),
-                batch_size=1,
-            )
-            generation: Optional[GenerationChunk] = None
-            try:
-                async for chunk in self._astream(
-                    prompt, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    yield chunk.text
-                    if generation is None:
-                        generation = chunk
-                    else:
-                        generation += chunk
-                assert generation is not None
-            except BaseException as e:
-                await run_manager.on_llm_error(
-                    e,
-                    response=LLMResult(
-                        generations=[[generation]] if generation else []
-                    ),
-                )
-                raise e
-            else:
-                await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+            raise e
+        else:
+            await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
 
     # --- Custom methods ---
 
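Net effect of the llms.py change: BaseLLM.astream now prefers a native _astream implementation, falls back to wrapping a synchronous _stream via _as_async_iterator, and only degrades to yielding the whole ainvoke() result as a single chunk when neither hook is overridden. The snippet below is an editorial sketch, not part of the patch: it shows the same sync-to-async bridging pattern using plain asyncio instead of langchain_core's run_in_executor helper, with a sentinel default so each next() call runs off the event loop.

import asyncio
from typing import Any, AsyncIterator, Iterator


async def iterate_in_executor(sync_iterator: Iterator[Any]) -> AsyncIterator[Any]:
    loop = asyncio.get_running_loop()
    done = object()  # sentinel returned by next() once the iterator is exhausted
    while True:
        # Each next() call runs in a worker thread, so a slow or blocking
        # sync _stream implementation cannot stall the event loop.
        item = await loop.run_in_executor(None, next, sync_iterator, done)
        if item is done:
            break
        yield item


async def _demo() -> None:
    async for token in iterate_in_executor(iter(["a", "b", "c"])):
        print(token)


if __name__ == "__main__":
    asyncio.run(_demo())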
diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py
index a6e866cf97..5d701b384b 100644
--- a/libs/core/tests/unit_tests/language_models/llms/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -1,6 +1,13 @@
+from typing import Any, AsyncIterator, Iterator, List, Optional
+
 import pytest
 
-from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ async def test_stream_error_callback() -> None:
                 pass
 
             eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            generations = [Generation(text="hello")]
+            return LLMResult(generations=[generations])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["hello"]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["hello"]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["a", "b"]
+    assert type(model)._astream == BaseLLM._astream
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == ["a", "b"]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["a", "b"]
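Taken together, the three tests pin down the resolution order astream now follows: a native _astream wins, a sync-only _stream is wrapped into an executor-backed async iterator, and a single-chunk ainvoke() remains the last resort. The script below is a hypothetical end-to-end usage sketch built from the same public pieces the tests import (BaseLLM, GenerationChunk, CallbackManagerForLLMRun); the class and function names are invented for illustration and are not part of the patch.

import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult


class HypotheticalSyncOnlyLLM(BaseLLM):
    """Example model that only implements the synchronous _stream hook."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        raise NotImplementedError

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        yield GenerationChunk(text="hello ")
        yield GenerationChunk(text="world")

    @property
    def _llm_type(self) -> str:
        return "hypothetical-sync-only"


async def main() -> None:
    model = HypotheticalSyncOnlyLLM()
    # With the fallback in place, astream() wraps the sync _stream in an
    # executor-backed async iterator and yields "hello " and "world" as
    # separate chunks instead of a single ainvoke() result.
    async for chunk in model.astream("anything"):
        print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())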