From 0689628489967785f3a11a9f29d8f6f90930f4f4 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 17 Aug 2023 23:46:23 -0700 Subject: [PATCH] Adds streaming for runnable maps (#9283) @nfcampos @baskaryan --------- Co-authored-by: Nuno Campos --- libs/langchain/langchain/chat_models/fake.py | 7 + libs/langchain/langchain/llms/fake.py | 7 + .../langchain/schema/output_parser.py | 2 + .../langchain/schema/runnable/base.py | 310 ++++++++++++++++-- .../langchain/schema/runnable/passthrough.py | 14 +- .../langchain/schema/runnable/utils.py | 17 +- libs/langchain/langchain/utils/aiter.py | 49 ++- libs/langchain/langchain/utils/iter.py | 162 +++++++++ .../tests/unit_tests/schema/test_runnable.py | 184 ++++++++++- 9 files changed, 704 insertions(+), 48 deletions(-) create mode 100644 libs/langchain/langchain/utils/iter.py diff --git a/libs/langchain/langchain/chat_models/fake.py b/libs/langchain/langchain/chat_models/fake.py index fe332016e6..1fe54fef64 100644 --- a/libs/langchain/langchain/chat_models/fake.py +++ b/libs/langchain/langchain/chat_models/fake.py @@ -1,4 +1,6 @@ """Fake ChatModel for testing purposes.""" +import asyncio +import time from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union from langchain.callbacks.manager import ( @@ -14,6 +16,7 @@ class FakeListChatModel(SimpleChatModel): """Fake ChatModel for testing purposes.""" responses: List + sleep: Optional[float] = None i: int = 0 @property @@ -48,6 +51,8 @@ class FakeListChatModel(SimpleChatModel): else: self.i = 0 for c in response: + if self.sleep is not None: + time.sleep(self.sleep) yield ChatGenerationChunk(message=AIMessageChunk(content=c)) async def _astream( @@ -63,6 +68,8 @@ class FakeListChatModel(SimpleChatModel): else: self.i = 0 for c in response: + if self.sleep is not None: + await asyncio.sleep(self.sleep) yield ChatGenerationChunk(message=AIMessageChunk(content=c)) @property diff --git a/libs/langchain/langchain/llms/fake.py b/libs/langchain/langchain/llms/fake.py index d8a5a7fd03..b3c22eb644 100644 --- a/libs/langchain/langchain/llms/fake.py +++ b/libs/langchain/langchain/llms/fake.py @@ -1,3 +1,5 @@ +import asyncio +import time from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional from langchain.callbacks.manager import ( @@ -13,6 +15,7 @@ class FakeListLLM(LLM): """Fake LLM for testing purposes.""" responses: List + sleep: Optional[float] = None i: int = 0 @property @@ -68,6 +71,8 @@ class FakeStreamingListLLM(FakeListLLM): ) -> Iterator[str]: result = self.invoke(input, config) for c in result: + if self.sleep is not None: + time.sleep(self.sleep) yield c async def astream( @@ -80,4 +85,6 @@ class FakeStreamingListLLM(FakeListLLM): ) -> AsyncIterator[str]: result = await self.ainvoke(input, config) for c in result: + if self.sleep is not None: + await asyncio.sleep(self.sleep) yield c diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 9290c4cede..8de216800f 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -277,6 +277,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): self, input: Iterator[Union[str, BaseMessage]], config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Iterator[T]: yield from self._transform_stream_with_config( input, self._transform, config, run_type="parser" @@ -286,6 +287,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): self, input: AsyncIterator[Union[str, BaseMessage]], config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> AsyncIterator[T]: async for chunk in self._atransform_stream_with_config( input, self._atransform, config, run_type="parser" diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c9d836cfd8..1119189bc8 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1,10 +1,13 @@ from __future__ import annotations import asyncio +import copy +import threading from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from itertools import tee from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -23,15 +26,25 @@ from typing import ( cast, ) +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + + from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.utils import ( + accepts_run_manager, + accepts_run_manager_and_config, gather_with_concurrency, ) from langchain.utils.aiter import atee, py_anext +from langchain.utils.iter import safetee Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do @@ -48,7 +61,7 @@ class Runnable(Generic[Input, Output], ABC): other: Union[ Runnable[Any, Other], Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], + Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: return RunnableSequence(first=self, last=coerce_to_runnable(other)) @@ -58,7 +71,7 @@ class Runnable(Generic[Input, Output], ABC): other: Union[ Runnable[Other, Any], Callable[[Any], Other], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], + Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: return RunnableSequence(first=coerce_to_runnable(other), last=self) @@ -135,7 +148,10 @@ class Runnable(Generic[Input, Output], ABC): yield await self.ainvoke(input, config) def transform( - self, input: Iterator[Input], config: Optional[RunnableConfig] = None + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: """ Default implementation of transform, which buffers input and then calls stream. @@ -152,10 +168,13 @@ class Runnable(Generic[Input, Output], ABC): # This method should throw an error if gathering fails. final += chunk # type: ignore[operator] if final: - yield from self.stream(final, config) + yield from self.stream(final, config, **kwargs) async def atransform( - self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: """ Default implementation of atransform, which buffers input and calls astream. @@ -173,7 +192,7 @@ class Runnable(Generic[Input, Output], ABC): final += chunk # type: ignore[operator] if final: - async for output in self.astream(final, config): + async for output in self.astream(final, config, **kwargs): yield output def bind(self, **kwargs: Any) -> Runnable[Input, Output]: @@ -217,7 +236,11 @@ class Runnable(Generic[Input, Output], ABC): def _call_with_config( self, - func: Callable[[Input], Output], + func: Union[ + Callable[[Input], Output], + Callable[[Input, CallbackManagerForChainRun], Output], + Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], + ], input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, @@ -238,7 +261,16 @@ class Runnable(Generic[Input, Output], ABC): run_type=run_type, ) try: - output = func(input) + if accepts_run_manager_and_config(func): + output = func( + input, + run_manager=run_manager, + config=config, + ) # type: ignore[call-arg] + elif accepts_run_manager(func): + output = func(input, run_manager=run_manager) # type: ignore[call-arg] + else: + output = func(input) # type: ignore[call-arg] except Exception as e: run_manager.on_chain_error(e) raise @@ -253,7 +285,14 @@ class Runnable(Generic[Input, Output], ABC): async def _acall_with_config( self, - func: Callable[[Input], Awaitable[Output]], + func: Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ], input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, @@ -274,7 +313,19 @@ class Runnable(Generic[Input, Output], ABC): run_type=run_type, ) try: - output = await func(input) + if accepts_run_manager_and_config(func): + output = await func( + input, + run_manager=run_manager, + config=config, + ) # type: ignore[call-arg] + elif accepts_run_manager(func): + output = await func( + input, + run_manager=run_manager, + ) # type: ignore[call-arg] + else: + output = await func(input) # type: ignore[call-arg] except Exception as e: await run_manager.on_chain_error(e) raise @@ -290,7 +341,18 @@ class Runnable(Generic[Input, Output], ABC): def _transform_stream_with_config( self, input: Iterator[Input], - transformer: Callable[[Iterator[Input]], Iterator[Output]], + transformer: Union[ + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]], + Callable[ + [ + Iterator[Input], + CallbackManagerForChainRun, + RunnableConfig, + ], + Iterator[Output], + ], + ], config: Optional[RunnableConfig], run_type: Optional[str] = None, ) -> Iterator[Output]: @@ -319,7 +381,20 @@ class Runnable(Generic[Input, Output], ABC): run_type=run_type, ) try: - for chunk in transformer(input_for_transform): + if accepts_run_manager_and_config(transformer): + iterator = transformer( + input_for_transform, + run_manager=run_manager, + config=config, + ) # type: ignore[call-arg] + elif accepts_run_manager(transformer): + iterator = transformer( + input_for_transform, + run_manager=run_manager, + ) # type: ignore[call-arg] + else: + iterator = transformer(input_for_transform) # type: ignore[call-arg] + for chunk in iterator: yield chunk if final_output_supported: if final_output is None: @@ -361,7 +436,21 @@ class Runnable(Generic[Input, Output], ABC): async def _atransform_stream_with_config( self, input: AsyncIterator[Input], - transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + transformer: Union[ + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + Callable[ + [AsyncIterator[Input], AsyncCallbackManagerForChainRun], + AsyncIterator[Output], + ], + Callable[ + [ + AsyncIterator[Input], + AsyncCallbackManagerForChainRun, + RunnableConfig, + ], + AsyncIterator[Output], + ], + ], config: Optional[RunnableConfig], run_type: Optional[str] = None, ) -> AsyncIterator[Output]: @@ -390,7 +479,22 @@ class Runnable(Generic[Input, Output], ABC): run_type=run_type, ) try: - async for chunk in transformer(input_for_transform): + # mypy can't quite work out thew type guard here, but this is safe, + # check implementations of the accepts_* functions + if accepts_run_manager_and_config(transformer): + iterator = transformer( + input_for_transform, + run_manager=run_manager, + config=config, + ) # type: ignore[call-arg] + elif accepts_run_manager(transformer): + iterator = transformer( + input_for_transform, + run_manager=run_manager, + ) # type: ignore[call-arg] + else: + iterator = transformer(input_for_transform) # type: ignore[call-arg] + async for chunk in iterator: yield chunk if final_output_supported: if final_output is None: @@ -700,7 +804,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): other: Union[ Runnable[Any, Other], Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], + Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: if isinstance(other, RunnableSequence): @@ -721,7 +825,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): other: Union[ Runnable[Other, Any], Callable[[Any], Other], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], + Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: if isinstance(other, RunnableSequence): @@ -875,7 +979,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) -> List[Output]: from langchain.callbacks.manager import ( AsyncCallbackManager, - AsyncCallbackManagerForChainRun, ) # setup callbacks @@ -1085,6 +1188,21 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) +class RunnableMapChunk(Dict[str, Any]): + """ + Partial output from a RunnableMap + """ + + def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk: + chunk = copy.deepcopy(self) + for key in other: + if key not in chunk or chunk[key] is None: + chunk[key] = other[key] + elif other[key] is not None: + chunk[key] += other[key] + return chunk + + class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): """ A runnable that runs a mapping of runnables in parallel, @@ -1134,7 +1252,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): local_metadata=None, ) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input}) + run_manager = callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) # gather results from all steps try: @@ -1177,7 +1297,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): ) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), {"input": input} + dumpd(self), input if isinstance(input, dict) else {"input": input} ) # gather results from all steps @@ -1203,6 +1323,134 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): await run_manager.on_chain_end(output) return output + def _transform( + self, + input: Iterator[Input], + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Iterator[RunnableMapChunk]: + # Shallow copy steps to ignore mutations while in progress + steps = dict(self.steps) + # Each step gets a copy of the input iterator, + # which is consumed in parallel in a separate thread. + input_copies = list(safetee(input, len(steps), lock=threading.Lock())) + with ThreadPoolExecutor() as executor: + # Create the transform() generator for each step + named_generators = [ + ( + name, + step.transform( + input_copies.pop(), + patch_config(config, run_manager.get_child()), + ), + ) + for name, step in steps.items() + ] + # Start the first iteration of each generator + futures = { + executor.submit(next, generator): (step_name, generator) + for step_name, generator in named_generators + } + # Yield chunks from each as they become available, + # and start the next iteration of that generator that yielded it. + # When all generators are exhausted, stop. + while futures: + completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED) + for future in completed_futures: + (step_name, generator) = futures.pop(future) + try: + chunk = RunnableMapChunk({step_name: future.result()}) + yield chunk + futures[executor.submit(next, generator)] = ( + step_name, + generator, + ) + except StopIteration: + pass + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + yield from self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Iterator[Dict[str, Any]]: + yield from self.transform(iter([input]), config) + + async def _atransform( + self, + input: AsyncIterator[Input], + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> AsyncIterator[RunnableMapChunk]: + # Shallow copy steps to ignore mutations while in progress + steps = dict(self.steps) + # Each step gets a copy of the input iterator, + # which is consumed in parallel in a separate thread. + input_copies = list(atee(input, len(steps), lock=asyncio.Lock())) + # Create the transform() generator for each step + named_generators = [ + ( + name, + step.atransform( + input_copies.pop(), patch_config(config, run_manager.get_child()) + ), + ) + for name, step in steps.items() + ] + + # Wrap in a coroutine to satisfy linter + async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]: + return await py_anext(generator) + + # Start the first iteration of each generator + tasks = { + asyncio.create_task(get_next_chunk(generator)): (step_name, generator) + for step_name, generator in named_generators + } + # Yield chunks from each as they become available, + # and start the next iteration of the generator that yielded it. + # When all generators are exhausted, stop. + while tasks: + completed_tasks, _ = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in completed_tasks: + (step_name, generator) = tasks.pop(task) + try: + chunk = RunnableMapChunk({step_name: task.result()}) + yield chunk + new_task = asyncio.create_task(get_next_chunk(generator)) + tasks[new_task] = (step_name, generator) + except StopAsyncIteration: + pass + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async for chunk in self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ): + yield chunk + + async def astream( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> AsyncIterator[Dict[str, Any]]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + async for chunk in self.atransform(input_aiter(), config): + yield chunk + class RunnableLambda(Runnable[Input, Output]): """ @@ -1293,14 +1541,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): yield item def transform( - self, input: Iterator[Input], config: Optional[RunnableConfig] = None + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Iterator[Output]: - yield from self.bound.transform(input, config, **self.kwargs) + yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs}) async def atransform( - self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> AsyncIterator[Output]: - async for item in self.bound.atransform(input, config, **self.kwargs): + async for item in self.bound.atransform( + input, config, **{**self.kwargs, **kwargs} + ): yield item @@ -1316,7 +1572,7 @@ def coerce_to_runnable( thing: Union[ Runnable[Input, Output], Callable[[Input], Output], - Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]], + Mapping[str, Any], ] ) -> Runnable[Input, Output]: if isinstance(thing, Runnable): @@ -1324,7 +1580,9 @@ def coerce_to_runnable( elif callable(thing): return RunnableLambda(thing) elif isinstance(thing, dict): - runnables = {key: coerce_to_runnable(r) for key, r in thing.items()} + runnables: Mapping[str, Runnable[Any, Any]] = { + key: coerce_to_runnable(r) for key, r in thing.items() + } return cast(Runnable[Input, Output], RunnableMap(steps=runnables)) else: raise TypeError( diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index a97e708b64..9ff26589ab 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import AsyncIterator, Iterator, List, Optional +from typing import Any, AsyncIterator, Iterator, List, Optional from langchain.load.serializable import Serializable from langchain.schema.runnable.base import Input, Runnable @@ -32,16 +32,22 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): return self._call_with_config(identity, input, config) async def ainvoke( - self, input: Input, config: RunnableConfig | None = None + self, input: Input, config: Optional[RunnableConfig] = None ) -> Input: return await self._acall_with_config(aidentity, input, config) def transform( - self, input: Iterator[Input], config: RunnableConfig | None = None + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Iterator[Input]: return self._transform_stream_with_config(input, identity, config) def atransform( - self, input: AsyncIterator[Input], config: RunnableConfig | None = None + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> AsyncIterator[Input]: return self._atransform_stream_with_config(input, identity, config) diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index 0602413bd6..2afa3705c4 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -1,7 +1,8 @@ from __future__ import annotations import asyncio -from typing import Any, Coroutine, Union +from inspect import signature +from typing import Any, Callable, Coroutine, Union async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: @@ -16,3 +17,17 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis semaphore = asyncio.Semaphore(n) return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros)) + + +def accepts_run_manager(callable: Callable[..., Any]) -> bool: + try: + return signature(callable).parameters.get("run_manager") is not None + except ValueError: + return False + + +def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool: + return ( + accepts_run_manager(callable) + and signature(callable).parameters.get("config") is not None + ) diff --git a/libs/langchain/langchain/utils/aiter.py b/libs/langchain/langchain/utils/aiter.py index 8670c823f9..a71650cdb4 100644 --- a/libs/langchain/langchain/utils/aiter.py +++ b/libs/langchain/langchain/utils/aiter.py @@ -7,6 +7,7 @@ MIT License from collections import deque from typing import ( Any, + AsyncContextManager, AsyncGenerator, AsyncIterator, Awaitable, @@ -15,6 +16,7 @@ from typing import ( Generic, Iterator, List, + Optional, Tuple, TypeVar, Union, @@ -64,33 +66,45 @@ def py_anext( return anext_impl() +class NoLock: + """Dummy lock that provides the proper interface but no protection""" + + async def __aenter__(self) -> None: + pass + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + return False + + async def tee_peer( iterator: AsyncIterator[T], # the buffer specific to this peer buffer: Deque[T], # the buffers of all peers, including our own peers: List[Deque[T]], + lock: AsyncContextManager[Any], ) -> AsyncGenerator[T, None]: """An individual iterator of a :py:func:`~.tee`""" try: while True: if not buffer: - # Another peer produced an item while we were waiting for the lock. - # Proceed with the next loop iteration to yield the item. - if buffer: - continue - try: - item = await iterator.__anext__() - except StopAsyncIteration: - break - else: - # Append to all buffers, including our own. We'll fetch our - # item from the buffer again, instead of yielding it directly. - # This ensures the proper item ordering if any of our peers - # are fetching items concurrently. They may have buffered their - # item already. - for peer_buffer in peers: - peer_buffer.append(item) + async with lock: + # Another peer produced an item while we were waiting for the lock. + # Proceed with the next loop iteration to yield the item. + if buffer: + continue + try: + item = await iterator.__anext__() + except StopAsyncIteration: + break + else: + # Append to all buffers, including our own. We'll fetch our + # item from the buffer again, instead of yielding it directly. + # This ensures the proper item ordering if any of our peers + # are fetching items concurrently. They may have buffered their + # item already. + for peer_buffer in peers: + peer_buffer.append(item) yield buffer.popleft() finally: # this peer is done – remove its buffer @@ -145,6 +159,8 @@ class Tee(Generic[T]): self, iterable: AsyncIterator[T], n: int = 2, + *, + lock: Optional[AsyncContextManager[Any]] = None, ): self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist self._buffers: List[Deque[T]] = [deque() for _ in range(n)] @@ -153,6 +169,7 @@ class Tee(Generic[T]): iterator=self._iterator, buffer=buffer, peers=self._buffers, + lock=lock if lock is not None else NoLock(), ) for buffer in self._buffers ) diff --git a/libs/langchain/langchain/utils/iter.py b/libs/langchain/langchain/utils/iter.py new file mode 100644 index 0000000000..498ccfbf73 --- /dev/null +++ b/libs/langchain/langchain/utils/iter.py @@ -0,0 +1,162 @@ +from collections import deque +from typing import ( + Any, + ContextManager, + Deque, + Generator, + Generic, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) + +from typing_extensions import Literal + +T = TypeVar("T") + + +class NoLock: + """Dummy lock that provides the proper interface but no protection""" + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: + return False + + +def tee_peer( + iterator: Iterator[T], + # the buffer specific to this peer + buffer: Deque[T], + # the buffers of all peers, including our own + peers: List[Deque[T]], + lock: ContextManager[Any], +) -> Generator[T, None, None]: + """An individual iterator of a :py:func:`~.tee`""" + try: + while True: + if not buffer: + with lock: + # Another peer produced an item while we were waiting for the lock. + # Proceed with the next loop iteration to yield the item. + if buffer: + continue + try: + item = next(iterator) + except StopIteration: + break + else: + # Append to all buffers, including our own. We'll fetch our + # item from the buffer again, instead of yielding it directly. + # This ensures the proper item ordering if any of our peers + # are fetching items concurrently. They may have buffered their + # item already. + for peer_buffer in peers: + peer_buffer.append(item) + yield buffer.popleft() + finally: + # this peer is done – remove its buffer + for idx, peer_buffer in enumerate(peers): # pragma: no branch + if peer_buffer is buffer: + peers.pop(idx) + break + # if we are the last peer, try and close the iterator + if not peers and hasattr(iterator, "close"): + iterator.close() + + +class Tee(Generic[T]): + """ + Create ``n`` separate asynchronous iterators over ``iterable`` + + This splits a single ``iterable`` into multiple iterators, each providing + the same items in the same order. + All child iterators may advance separately but share the same items + from ``iterable`` -- when the most advanced iterator retrieves an item, + it is buffered until the least advanced iterator has yielded it as well. + A ``tee`` works lazily and can handle an infinite ``iterable``, provided + that all iterators advance. + + .. code-block:: python3 + + async def derivative(sensor_data): + previous, current = a.tee(sensor_data, n=2) + await a.anext(previous) # advance one iterator + return a.map(operator.sub, previous, current) + + Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead + of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked + to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method + immediately closes all children, and it can be used in an ``async with`` context + for the same effect. + + If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* + provide these items. Also, ``tee`` must internally buffer each item until the + last iterator has yielded it; if the most and least advanced iterator differ + by most data, using a :py:class:`list` is more efficient (but not lazy). + + If the underlying iterable is concurrency safe (``anext`` may be awaited + concurrently) the resulting iterators are concurrency safe as well. Otherwise, + the iterators are safe if there is only ever one single "most advanced" iterator. + To enforce sequential use of ``anext``, provide a ``lock`` + - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application - + and access is automatically synchronised. + """ + + def __init__( + self, + iterable: Iterator[T], + n: int = 2, + *, + lock: Optional[ContextManager[Any]] = None, + ): + self._iterator = iter(iterable) + self._buffers: List[Deque[T]] = [deque() for _ in range(n)] + self._children = tuple( + tee_peer( + iterator=self._iterator, + buffer=buffer, + peers=self._buffers, + lock=lock if lock is not None else NoLock(), + ) + for buffer in self._buffers + ) + + def __len__(self) -> int: + return len(self._children) + + @overload + def __getitem__(self, item: int) -> Iterator[T]: + ... + + @overload + def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: + ... + + def __getitem__( + self, item: Union[int, slice] + ) -> Union[Iterator[T], Tuple[Iterator[T], ...]]: + return self._children[item] + + def __iter__(self) -> Iterator[Iterator[T]]: + yield from self._children + + def __enter__(self) -> "Tee[T]": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: + self.close() + return False + + def close(self) -> None: + for child in self._children: + child.close() + + +# Why this is needed https://stackoverflow.com/a/44638570 +safetee = Tee diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py index c0cae4d9bd..d139edf300 100644 --- a/libs/langchain/tests/unit_tests/schema/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -528,7 +528,7 @@ Question: parser = CommaSeparatedListOutputParser() - chain = ( + chain: Runnable = ( { "question": RunnablePassthrough[str]() | passthrough, "documents": passthrough | retriever, @@ -760,6 +760,188 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert len(map_run.child_runs) == 3 +def test_map_stream() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + + chat_res = "i'm a chatbot" + # sleep to better simulate a real stream + chat = FakeListChatModel(responses=[chat_res], sleep=0.01) + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + chain: Runnable = prompt | { + "chat": chat.bind(stop=["Thought:"]), + "llm": llm, + "passthrough": RunnablePassthrough(), + } + + stream = chain.stream({"question": "What is your name?"}) + + final_value = None + streamed_chunks = [] + for chunk in stream: + streamed_chunks.append(chunk) + if final_value is None: + final_value = chunk + else: + final_value += chunk + + assert streamed_chunks[0] in [ + {"passthrough": prompt.invoke({"question": "What is your name?"})}, + {"llm": "i"}, + {"chat": "i"}, + ] + assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 + assert all(len(c.keys()) == 1 for c in streamed_chunks) + assert final_value is not None + assert final_value.get("chat").content == "i'm a chatbot" + assert final_value.get("llm") == "i'm a textbot" + assert final_value.get("passthrough") == prompt.invoke( + {"question": "What is your name?"} + ) + + +def test_map_stream_iterator_input() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + + chat_res = "i'm a chatbot" + # sleep to better simulate a real stream + chat = FakeListChatModel(responses=[chat_res], sleep=0.01) + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + chain: Runnable = ( + prompt + | llm + | { + "chat": chat.bind(stop=["Thought:"]), + "llm": llm, + "passthrough": RunnablePassthrough(), + } + ) + + stream = chain.stream({"question": "What is your name?"}) + + final_value = None + streamed_chunks = [] + for chunk in stream: + streamed_chunks.append(chunk) + if final_value is None: + final_value = chunk + else: + final_value += chunk + + assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}] + assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res) + assert all(len(c.keys()) == 1 for c in streamed_chunks) + assert final_value is not None + assert final_value.get("chat").content == "i'm a chatbot" + assert final_value.get("llm") == "i'm a textbot" + assert final_value.get("passthrough") == "i'm a textbot" + + +@pytest.mark.asyncio +async def test_map_astream() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + + chat_res = "i'm a chatbot" + # sleep to better simulate a real stream + chat = FakeListChatModel(responses=[chat_res], sleep=0.01) + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + chain: Runnable = prompt | { + "chat": chat.bind(stop=["Thought:"]), + "llm": llm, + "passthrough": RunnablePassthrough(), + } + + stream = chain.astream({"question": "What is your name?"}) + + final_value = None + streamed_chunks = [] + async for chunk in stream: + streamed_chunks.append(chunk) + if final_value is None: + final_value = chunk + else: + final_value += chunk + + assert streamed_chunks[0] in [ + {"passthrough": prompt.invoke({"question": "What is your name?"})}, + {"llm": "i"}, + {"chat": "i"}, + ] + assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 + assert all(len(c.keys()) == 1 for c in streamed_chunks) + assert final_value is not None + assert final_value.get("chat").content == "i'm a chatbot" + assert final_value.get("llm") == "i'm a textbot" + assert final_value.get("passthrough") == prompt.invoke( + {"question": "What is your name?"} + ) + + +@pytest.mark.asyncio +async def test_map_astream_iterator_input() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + + chat_res = "i'm a chatbot" + # sleep to better simulate a real stream + chat = FakeListChatModel(responses=[chat_res], sleep=0.01) + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + chain: Runnable = ( + prompt + | llm + | { + "chat": chat.bind(stop=["Thought:"]), + "llm": llm, + "passthrough": RunnablePassthrough(), + } + ) + + stream = chain.astream({"question": "What is your name?"}) + + final_value = None + streamed_chunks = [] + async for chunk in stream: + streamed_chunks.append(chunk) + if final_value is None: + final_value = chunk + else: + final_value += chunk + + assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}] + assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res) + assert all(len(c.keys()) == 1 for c in streamed_chunks) + assert final_value is not None + assert final_value.get("chat").content == "i'm a chatbot" + assert final_value.get("llm") == "i'm a textbot" + assert final_value.get("passthrough") == llm_res + + def test_bind_bind() -> None: llm = FakeListLLM(responses=["i'm a textbot"])