Adds streaming for runnable maps (#9283)

@nfcampos @baskaryan

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Jacob Lee (committed via GitHub)
parent 0dd2c21089
commit 0689628489
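Reviewer note: for context, a minimal sketch of the behavior this PR enables (illustrative only, not part of the diff; it assumes RunnableMap, RunnableLambda and RunnablePassthrough are exported from langchain.schema.runnable, as the imports elsewhere in this PR suggest):

from langchain.schema.runnable import (
    RunnableLambda,
    RunnableMap,
    RunnablePassthrough,
)

# A map runs several runnables in parallel over the same input; with this PR
# its output can be streamed, one {step_name: partial_output} chunk at a time.
chain = RunnableMap(
    steps={
        "doubled": RunnableLambda(lambda x: x * 2),
        "passthrough": RunnablePassthrough(),
    }
)

for chunk in chain.stream(3):
    print(chunk)  # e.g. {"doubled": 6}, then {"passthrough": 3} (order may vary)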

@ -1,4 +1,6 @@
"""Fake ChatModel for testing purposes."""
import asyncio
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import (
@ -14,6 +16,7 @@ class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
responses: List
sleep: Optional[float] = None
i: int = 0
@property
@ -48,6 +51,8 @@ class FakeListChatModel(SimpleChatModel):
else:
self.i = 0
for c in response:
if self.sleep is not None:
time.sleep(self.sleep)
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
async def _astream(
@ -63,6 +68,8 @@ class FakeListChatModel(SimpleChatModel):
else:
self.i = 0
for c in response:
if self.sleep is not None:
await asyncio.sleep(self.sleep)
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property

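The new sleep field lets the fakes emit chunks with a realistic delay. A quick hypothetical usage (FakeListChatModel lives in the unit-test helpers, and .stream() is assumed available on the chat model base class, as the tests below rely on):

chat = FakeListChatModel(responses=["hi"], sleep=0.01)
for chunk in chat.stream("ignored input"):
    print(chunk.content)  # "h", then "i", roughly 10ms apart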
@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
from langchain.callbacks.manager import (
@ -13,6 +15,7 @@ class FakeListLLM(LLM):
"""Fake LLM for testing purposes."""
responses: List
sleep: Optional[float] = None
i: int = 0
@property
@ -68,6 +71,8 @@ class FakeStreamingListLLM(FakeListLLM):
) -> Iterator[str]:
result = self.invoke(input, config)
for c in result:
if self.sleep is not None:
time.sleep(self.sleep)
yield c
async def astream(
@ -80,4 +85,6 @@ class FakeStreamingListLLM(FakeListLLM):
) -> AsyncIterator[str]:
result = await self.ainvoke(input, config)
for c in result:
if self.sleep is not None:
await asyncio.sleep(self.sleep)
yield c

@ -277,6 +277,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
@ -286,6 +287,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"

@ -1,10 +1,13 @@
from __future__ import annotations
import asyncio
import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from itertools import tee
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
@ -23,15 +26,25 @@ from typing import (
cast,
)
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.utils import (
accepts_run_manager,
accepts_run_manager_and_config,
gather_with_concurrency,
)
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
Input = TypeVar("Input")
# Output type should implement __concat__, as e.g. str, list, dict do
@ -48,7 +61,7 @@ class Runnable(Generic[Input, Output], ABC):
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
return RunnableSequence(first=self, last=coerce_to_runnable(other))
@ -58,7 +71,7 @@ class Runnable(Generic[Input, Output], ABC):
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=coerce_to_runnable(other), last=self)
@ -135,7 +148,10 @@ class Runnable(Generic[Input, Output], ABC):
yield await self.ainvoke(input, config)
def transform(
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""
Default implementation of transform, which buffers input and then calls stream.
@ -152,10 +168,13 @@ class Runnable(Generic[Input, Output], ABC):
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
if final:
yield from self.stream(final, config)
yield from self.stream(final, config, **kwargs)
async def atransform(
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""
Default implementation of atransform, which buffers input and calls astream.
@ -173,7 +192,7 @@ class Runnable(Generic[Input, Output], ABC):
final += chunk # type: ignore[operator]
if final:
async for output in self.astream(final, config):
async for output in self.astream(final, config, **kwargs):
yield output
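To illustrate the buffering behaviour of these defaults (a sketch, assuming RunnableLambda does not override transform): the whole input iterator is concatenated with + before stream runs, so a plain lambda sees one final value.

from langchain.schema.runnable import RunnableLambda

upper = RunnableLambda(lambda s: s.upper())
# "he" + "llo" is buffered into "hello", then streamed as a single output:
assert list(upper.transform(iter(["he", "llo"]))) == ["HELLO"]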
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
@ -217,7 +236,11 @@ class Runnable(Generic[Input, Output], ABC):
def _call_with_config(
self,
func: Callable[[Input], Output],
func: Union[
Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
@ -238,7 +261,16 @@ class Runnable(Generic[Input, Output], ABC):
run_type=run_type,
)
try:
output = func(input)
if accepts_run_manager_and_config(func):
output = func(
input,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = func(input, run_manager=run_manager) # type: ignore[call-arg]
else:
output = func(input) # type: ignore[call-arg]
except Exception as e:
run_manager.on_chain_error(e)
raise
@ -253,7 +285,14 @@ class Runnable(Generic[Input, Output], ABC):
async def _acall_with_config(
self,
func: Callable[[Input], Awaitable[Output]],
func: Union[
Callable[[Input], Awaitable[Output]],
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
Callable[
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
Awaitable[Output],
],
],
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
@ -274,7 +313,19 @@ class Runnable(Generic[Input, Output], ABC):
run_type=run_type,
)
try:
output = await func(input)
if accepts_run_manager_and_config(func):
output = await func(
input,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = await func(
input,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
output = await func(input) # type: ignore[call-arg]
except Exception as e:
await run_manager.on_chain_error(e)
raise
@ -290,7 +341,18 @@ class Runnable(Generic[Input, Output], ABC):
def _transform_stream_with_config(
self,
input: Iterator[Input],
transformer: Callable[[Iterator[Input]], Iterator[Output]],
transformer: Union[
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
Callable[
[
Iterator[Input],
CallbackManagerForChainRun,
RunnableConfig,
],
Iterator[Output],
],
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Iterator[Output]:
@ -319,7 +381,20 @@ class Runnable(Generic[Input, Output], ABC):
run_type=run_type,
)
try:
for chunk in transformer(input_for_transform):
if accepts_run_manager_and_config(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
iterator = transformer(input_for_transform) # type: ignore[call-arg]
for chunk in iterator:
yield chunk
if final_output_supported:
if final_output is None:
@ -361,7 +436,21 @@ class Runnable(Generic[Input, Output], ABC):
async def _atransform_stream_with_config(
self,
input: AsyncIterator[Input],
transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
transformer: Union[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Callable[
[AsyncIterator[Input], AsyncCallbackManagerForChainRun],
AsyncIterator[Output],
],
Callable[
[
AsyncIterator[Input],
AsyncCallbackManagerForChainRun,
RunnableConfig,
],
AsyncIterator[Output],
],
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> AsyncIterator[Output]:
@ -390,7 +479,22 @@ class Runnable(Generic[Input, Output], ABC):
run_type=run_type,
)
try:
async for chunk in transformer(input_for_transform):
# mypy can't quite work out the type guard here, but this is safe,
# check implementations of the accepts_* functions
if accepts_run_manager_and_config(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
iterator = transformer(input_for_transform) # type: ignore[call-arg]
async for chunk in iterator:
yield chunk
if final_output_supported:
if final_output is None:
@ -700,7 +804,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
if isinstance(other, RunnableSequence):
@ -721,7 +825,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
if isinstance(other, RunnableSequence):
@ -875,7 +979,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
# setup callbacks
@ -1085,6 +1188,21 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
class RunnableMapChunk(Dict[str, Any]):
"""
Partial output from a RunnableMap
"""
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = copy.deepcopy(self)
for key in other:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
chunk[key] += other[key]
return chunk
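In other words, + merges per-step partial outputs: keys new to the left chunk are adopted, existing keys are concatenated. A small illustration (not part of the diff):

a = RunnableMapChunk({"llm": "i"})
b = RunnableMapChunk({"llm": "'m"})
c = RunnableMapChunk({"chat": "hi"})
assert a + b == {"llm": "i'm"}
assert a + b + c == {"llm": "i'm", "chat": "hi"}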
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
@ -1134,7 +1252,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# gather results from all steps
try:
@ -1177,7 +1297,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input}
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# gather results from all steps
@ -1203,6 +1323,134 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
await run_manager.on_chain_end(output)
return output
def _transform(
self,
input: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[RunnableMapChunk]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
with ThreadPoolExecutor() as executor:
# Create the transform() generator for each step
named_generators = [
(
name,
step.transform(
input_copies.pop(),
patch_config(config, run_manager.get_child()),
),
)
for name, step in steps.items()
]
# Start the first iteration of each generator
futures = {
executor.submit(next, generator): (step_name, generator)
for step_name, generator in named_generators
}
# Yield chunks from each as they become available,
# and start the next iteration of the generator that yielded it.
# When all generators are exhausted, stop.
while futures:
completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED)
for future in completed_futures:
(step_name, generator) = futures.pop(future)
try:
chunk = RunnableMapChunk({step_name: future.result()})
yield chunk
futures[executor.submit(next, generator)] = (
step_name,
generator,
)
except StopIteration:
pass
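The concurrency pattern above is worth isolating: submit one next() call per generator to a thread pool, wait for whichever finishes first, yield its result, and re-submit for that generator until all are exhausted. A minimal standalone sketch of the same technique (hypothetical helper, not langchain code):

from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Dict, Iterator, Tuple

def interleave(named: Dict[str, Iterator]) -> Iterator[Tuple[str, object]]:
    with ThreadPoolExecutor() as executor:
        # one in-flight next() call per generator
        futures = {
            executor.submit(next, gen): (name, gen) for name, gen in named.items()
        }
        while futures:
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for future in done:
                name, gen = futures.pop(future)
                try:
                    yield name, future.result()
                    # re-arm this generator for its next item
                    futures[executor.submit(next, gen)] = (name, gen)
                except StopIteration:
                    pass  # this generator is exhausted

print(list(interleave({"a": iter("xy"), "b": iter("123")})))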
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Dict[str, Any]]:
yield from self.transform(iter([input]), config)
async def _atransform(
self,
input: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[RunnableMapChunk]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate task.
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
# Create the transform() generator for each step
named_generators = [
(
name,
step.atransform(
input_copies.pop(), patch_config(config, run_manager.get_child())
),
)
for name, step in steps.items()
]
# Wrap in a coroutine to satisfy linter
async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]:
return await py_anext(generator)
# Start the first iteration of each generator
tasks = {
asyncio.create_task(get_next_chunk(generator)): (step_name, generator)
for step_name, generator in named_generators
}
# Yield chunks from each as they become available,
# and start the next iteration of the generator that yielded it.
# When all generators are exhausted, stop.
while tasks:
completed_tasks, _ = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in completed_tasks:
(step_name, generator) = tasks.pop(task)
try:
chunk = RunnableMapChunk({step_name: task.result()})
yield chunk
new_task = asyncio.create_task(get_next_chunk(generator))
tasks[new_task] = (step_name, generator)
except StopAsyncIteration:
pass
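The asyncio analogue follows the same shape: race one __anext__ task per generator with asyncio.wait(..., return_when=FIRST_COMPLETED) and re-arm whichever finished. A standalone sketch (not langchain code):

import asyncio
from typing import AsyncIterator, Dict, Tuple

async def ainterleave(
    named: Dict[str, AsyncIterator],
) -> AsyncIterator[Tuple[str, object]]:
    # __anext__() returns a coroutine, so it can be wrapped in a task directly
    tasks = {
        asyncio.create_task(gen.__anext__()): (name, gen)
        for name, gen in named.items()
    }
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            name, gen = tasks.pop(task)
            try:
                yield name, task.result()
                tasks[asyncio.create_task(gen.__anext__())] = (name, gen)
            except StopAsyncIteration:
                pass  # this generator is exhausted

async def demo() -> None:
    async def letters(s: str) -> AsyncIterator[str]:
        for ch in s:
            await asyncio.sleep(0)
            yield ch

    async for name, item in ainterleave({"a": letters("xy"), "b": letters("123")}):
        print(name, item)

asyncio.run(demo())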
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
):
yield chunk
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Input]:
yield input
async for chunk in self.atransform(input_aiter(), config):
yield chunk
class RunnableLambda(Runnable[Input, Output]):
"""
@ -1293,14 +1541,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item
def transform(
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Output]:
yield from self.bound.transform(input, config, **self.kwargs)
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
async def atransform(
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(input, config, **self.kwargs):
async for item in self.bound.atransform(
input, config, **{**self.kwargs, **kwargs}
):
yield item
@ -1316,7 +1572,7 @@ def coerce_to_runnable(
thing: Union[
Runnable[Input, Output],
Callable[[Input], Output],
Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
Mapping[str, Any],
]
) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
@ -1324,7 +1580,9 @@ def coerce_to_runnable(
elif callable(thing):
return RunnableLambda(thing)
elif isinstance(thing, dict):
runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
runnables: Mapping[str, Runnable[Any, Any]] = {
key: coerce_to_runnable(r) for key, r in thing.items()
}
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
else:
raise TypeError(

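With the loosened Mapping[str, Any] signature, dict literals mixing plain callables and runnables coerce cleanly into a RunnableMap. An illustrative use of coerce_to_runnable (hypothetical values):

runnable = coerce_to_runnable(
    {"square": lambda x: x * x, "same": RunnablePassthrough()}
)
assert runnable.invoke(4) == {"square": 16, "same": 4}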
@ -1,6 +1,6 @@
from __future__ import annotations
from typing import AsyncIterator, Iterator, List, Optional
from typing import Any, AsyncIterator, Iterator, List, Optional
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable
@ -32,16 +32,22 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
return self._call_with_config(identity, input, config)
async def ainvoke(
self, input: Input, config: RunnableConfig | None = None
self, input: Input, config: Optional[RunnableConfig] = None
) -> Input:
return await self._acall_with_config(aidentity, input, config)
def transform(
self, input: Iterator[Input], config: RunnableConfig | None = None
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Input]:
return self._transform_stream_with_config(input, identity, config)
def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Input]:
return self._atransform_stream_with_config(input, identity, config)

@ -1,7 +1,8 @@
from __future__ import annotations
import asyncio
from typing import Any, Coroutine, Union
from inspect import signature
from typing import Any, Callable, Coroutine, Union
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@ -16,3 +17,17 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
try:
return signature(callable).parameters.get("run_manager") is not None
except ValueError:
return False
def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool:
return (
accepts_run_manager(callable)
and signature(callable).parameters.get("config") is not None
)
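These helpers drive the dispatch in _call_with_config and friends; they simply look for parameters by name (the ValueError catch covers callables without inspectable signatures). A quick illustration with hypothetical functions:

def plain(input): ...
def with_manager(input, run_manager): ...
def with_both(input, run_manager, config): ...

assert not accepts_run_manager(plain)
assert accepts_run_manager(with_manager)
assert not accepts_run_manager_and_config(with_manager)
assert accepts_run_manager_and_config(with_both)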

@ -7,6 +7,7 @@ MIT License
from collections import deque
from typing import (
Any,
AsyncContextManager,
AsyncGenerator,
AsyncIterator,
Awaitable,
@ -15,6 +16,7 @@ from typing import (
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
@ -64,33 +66,45 @@ def py_anext(
return anext_impl()
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
async def __aenter__(self) -> None:
pass
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
return False
async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
async with lock:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
# this peer is done, remove its buffer
@ -145,6 +159,8 @@ class Tee(Generic[T]):
self,
iterable: AsyncIterator[T],
n: int = 2,
*,
lock: Optional[AsyncContextManager[Any]] = None,
):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
@ -153,6 +169,7 @@ class Tee(Generic[T]):
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
lock=lock if lock is not None else NoLock(),
)
for buffer in self._buffers
)
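Hypothetical usage of the now-lockable async tee (assuming atee is the module-level alias for this Tee, as the import in base.py above suggests):

import asyncio
from typing import AsyncIterator

from langchain.utils.aiter import atee

async def main() -> None:
    async def numbers() -> AsyncIterator[int]:
        for i in range(3):
            yield i

    left, right = atee(numbers(), 2, lock=asyncio.Lock())
    # Both children replay the same items, safe for concurrent consumers.
    print([x async for x in left], [x async for x in right])

asyncio.run(main())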

@ -0,0 +1,162 @@
from collections import deque
from typing import (
Any,
ContextManager,
Deque,
Generator,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import Literal
T = TypeVar("T")
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
return False
def tee_peer(
iterator: Iterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: ContextManager[Any],
) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
with lock:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = next(iterator)
except StopIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
# this peer is done, remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and hasattr(iterator, "close"):
iterator.close()
class Tee(Generic[T]):
"""
Create ``n`` separate iterators over ``iterable``
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.
.. code-block:: python3
    def derivative(sensor_data):
        previous, current = tee(sensor_data, n=2)
        next(previous)  # advance one iterator
        return map(operator.sub, previous, current)
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.close` method
immediately closes all children, and it can be used in a ``with`` context
for the same effect.
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterators differ
by most of the data, using a :py:class:`list` is more efficient (but not lazy).
If the underlying iterable is concurrency safe (``next`` may be called
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``next``, provide a ``lock``
- e.g. a :py:class:`threading.Lock` instance in a multi-threaded application -
and access is automatically synchronised.
"""
def __init__(
self,
iterable: Iterator[T],
n: int = 2,
*,
lock: Optional[ContextManager[Any]] = None,
):
self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
lock=lock if lock is not None else NoLock(),
)
for buffer in self._buffers
)
def __len__(self) -> int:
return len(self._children)
@overload
def __getitem__(self, item: int) -> Iterator[T]:
...
@overload
def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]:
...
def __getitem__(
self, item: Union[int, slice]
) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[Iterator[T]]:
yield from self._children
def __enter__(self) -> "Tee[T]":
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
self.close()
return False
def close(self) -> None:
for child in self._children:
child.close()
# Why this is needed https://stackoverflow.com/a/44638570
safetee = Tee
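A minimal illustration of safetee (the lock only matters when children are consumed from different threads; this sequential use just shows the shared items):

import threading

left, right = safetee(iter(range(3)), 2, lock=threading.Lock())
assert list(left) == [0, 1, 2]
assert list(right) == [0, 1, 2]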

@ -528,7 +528,7 @@ Question:
parser = CommaSeparatedListOutputParser()
chain = (
chain: Runnable = (
{
"question": RunnablePassthrough[str]() | passthrough,
"documents": passthrough | retriever,
@ -760,6 +760,188 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert len(map_run.child_runs) == 3
def test_map_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat_res = "i'm a chatbot"
# sleep to better simulate a real stream
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
chain: Runnable = prompt | {
"chat": chat.bind(stop=["Thought:"]),
"llm": llm,
"passthrough": RunnablePassthrough(),
}
stream = chain.stream({"question": "What is your name?"})
final_value = None
streamed_chunks = []
for chunk in stream:
streamed_chunks.append(chunk)
if final_value is None:
final_value = chunk
else:
final_value += chunk
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": "i"},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None
assert final_value.get("chat").content == "i'm a chatbot"
assert final_value.get("llm") == "i'm a textbot"
assert final_value.get("passthrough") == prompt.invoke(
{"question": "What is your name?"}
)
def test_map_stream_iterator_input() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat_res = "i'm a chatbot"
# sleep to better simulate a real stream
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
chain: Runnable = (
prompt
| llm
| {
"chat": chat.bind(stop=["Thought:"]),
"llm": llm,
"passthrough": RunnablePassthrough(),
}
)
stream = chain.stream({"question": "What is your name?"})
final_value = None
streamed_chunks = []
for chunk in stream:
streamed_chunks.append(chunk)
if final_value is None:
final_value = chunk
else:
final_value += chunk
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None
assert final_value.get("chat").content == "i'm a chatbot"
assert final_value.get("llm") == "i'm a textbot"
assert final_value.get("passthrough") == "i'm a textbot"
@pytest.mark.asyncio
async def test_map_astream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat_res = "i'm a chatbot"
# sleep to better simulate a real stream
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
chain: Runnable = prompt | {
"chat": chat.bind(stop=["Thought:"]),
"llm": llm,
"passthrough": RunnablePassthrough(),
}
stream = chain.astream({"question": "What is your name?"})
final_value = None
streamed_chunks = []
async for chunk in stream:
streamed_chunks.append(chunk)
if final_value is None:
final_value = chunk
else:
final_value += chunk
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": "i"},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None
assert final_value.get("chat").content == "i'm a chatbot"
assert final_value.get("llm") == "i'm a textbot"
assert final_value.get("passthrough") == prompt.invoke(
{"question": "What is your name?"}
)
@pytest.mark.asyncio
async def test_map_astream_iterator_input() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat_res = "i'm a chatbot"
# sleep to better simulate a real stream
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
chain: Runnable = (
prompt
| llm
| {
"chat": chat.bind(stop=["Thought:"]),
"llm": llm,
"passthrough": RunnablePassthrough(),
}
)
stream = chain.astream({"question": "What is your name?"})
final_value = None
streamed_chunks = []
async for chunk in stream:
streamed_chunks.append(chunk)
if final_value is None:
final_value = chunk
else:
final_value += chunk
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None
assert final_value.get("chat").content == "i'm a chatbot"
assert final_value.get("llm") == "i'm a textbot"
assert final_value.get("passthrough") == llm_res
def test_bind_bind() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])
