Adds transform support for runnables (#8762)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change,
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
  - Tag maintainer: for a quicker response, tag the relevant maintainer (see below),
  - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
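A minimal sketch of what this change enables, mirroring the new `test_deep_stream` test at the bottom of this diff (every class used below appears in the diff): a sequence now streams through trailing steps that override `transform()`, so the parser emits one chunk per LLM chunk instead of one buffered string.

```python
from langchain.llms.fake import FakeStreamingListLLM
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.schema.output_parser import StrOutputParser

prompt = (
    SystemMessagePromptTemplate.from_template("You are a nice assistant.")
    + "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()

# The prompt step is invoked once; the llm streams and the parser transforms
# that stream, so each character of "foo-lish" arrives as its own chunk.
for chunk in chain.stream({"question": "What up"}):
    print(chunk, end="")
```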

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

@@ -47,6 +47,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run = self.run_map[str(run.parent_run_id)]
if parent_run:
self._add_child_run(parent_run, run)
parent_run.child_execution_order = max(
parent_run.child_execution_order, run.child_execution_order
)
else:
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
self.run_map[str(run.id)] = run
@@ -254,7 +257,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._on_chain_start(chain_run)
def on_chain_end(
-self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
+self,
+outputs: Dict[str, Any],
+*,
+run_id: UUID,
+inputs: Optional[Dict[str, Any]] = None,
+**kwargs: Any,
) -> None:
"""End a trace for a chain run."""
if not run_id:
@@ -266,6 +274,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
chain_run.outputs = outputs
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
self._end_trace(chain_run)
self._on_chain_end(chain_run)
@@ -273,6 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self,
error: Union[Exception, KeyboardInterrupt],
*,
inputs: Optional[Dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> None:
@@ -286,6 +297,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
self._end_trace(chain_run)
self._on_chain_error(chain_run)
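A sketch of the new tracer hook above, assuming a minimal hypothetical `BaseTracer` subclass (`PrintTracer` is illustrative only). A streaming run knows its full input only once its input iterator is exhausted, so `on_chain_end()`/`on_chain_error()` now accept an `inputs` override that replaces the placeholder recorded at start:

```python
from uuid import uuid4

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run

class PrintTracer(BaseTracer):  # hypothetical subclass, for illustration
    def _persist_run(self, run: Run) -> None:
        print(run.inputs)

tracer = PrintTracer()
run_id = uuid4()
# At start, a streaming run can only report a placeholder input.
tracer.on_chain_start({"name": "demo"}, {"input": ""}, run_id=run_id)
# At end, the accumulated input overwrites the placeholder.
tracer.on_chain_end({"output": "done"}, run_id=run_id, inputs={"input": "accumulated"})
# prints: {'input': 'accumulated'}
```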

@@ -1,9 +1,11 @@
"""Base interface that all chains should implement."""
import asyncio
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -55,18 +57,26 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
"""
def invoke(
-self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+self,
+input: Dict[str, Any],
+config: Optional[RunnableConfig] = None,
+**kwargs: Any,
) -> Dict[str, Any]:
-return self(input, **(config or {}))
+return self(input, **(config or {}), **kwargs)
async def ainvoke(
-self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+self,
+input: Dict[str, Any],
+config: Optional[RunnableConfig] = None,
+**kwargs: Any,
) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
-return await super().ainvoke(input, config)
+return await asyncio.get_running_loop().run_in_executor(
+None, partial(self.invoke, input, config, **kwargs)
+)
-return await self.acall(input, **(config or {}))
+return await self.acall(input, **(config or {}), **kwargs)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.

@@ -3,6 +3,8 @@ import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@@ -27,9 +29,11 @@ class TransformChain(Chain):
"""The keys expected by the transform's input dictionary."""
output_variables: List[str]
"""The keys returned by the transform's output dictionary."""
-transform: Callable[[Dict[str, str]], Dict[str, str]]
+transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
"""The transform function."""
-atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
+atransform_cb: Optional[
+Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
+] = Field(None, alias="atransform")
"""The async coroutine transform function."""
@staticmethod
@@ -62,18 +66,18 @@
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
-return self.transform(inputs)
+return self.transform_cb(inputs)
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
-if self.atransform is not None:
-return await self.atransform(inputs)
+if self.atransform_cb is not None:
+return await self.atransform_cb(inputs)
else:
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform"
)
-return self.transform(inputs)
+return self.transform_cb(inputs)
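The rename to `transform_cb` is internal only: the pydantic `alias` keeps `transform=` (and `atransform=`) working as constructor arguments, which is why the existing test below needs no change beyond its name. A small sketch:

```python
from langchain.chains import TransformChain

greet = TransformChain(
    input_variables=["first_name", "last_name"],
    output_variables=["greeting"],
    # still passed as `transform`; stored on the aliased `transform_cb` field
    transform=lambda d: {"greeting": f"Hi {d['first_name']} {d['last_name']}"},
)
print(greet.invoke({"first_name": "Ada", "last_name": "Lovelace"})["greeting"])
```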

@@ -1,9 +1,13 @@
"""Fake ChatModel for testing purposes."""
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+AsyncCallbackManagerForLLMRun,
+CallbackManagerForLLMRun,
+)
from langchain.chat_models.base import SimpleChatModel
-from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGenerationChunk
class FakeListChatModel(SimpleChatModel):
@@ -31,6 +35,36 @@ class FakeListChatModel(SimpleChatModel):
self.i = 0
return response
def _stream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
for c in response:
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
for c in response:
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"responses": self.responses}

@@ -1,10 +1,12 @@
-from typing import Any, List, Mapping, Optional
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.runnable import RunnableConfig
class FakeListLLM(LLM):
@@ -51,3 +53,29 @@ class FakeListLLM(LLM):
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"responses": self.responses}
class FakeStreamingListLLM(FakeListLLM):
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
result = self.invoke(input, config)
for c in result:
yield c
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
result = await self.ainvoke(input, config)
for c in result:
yield c

@@ -2,7 +2,17 @@ from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+Any,
+AsyncIterator,
+Dict,
+Generic,
+Iterator,
+List,
+Optional,
+TypeVar,
+Union,
+)
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
@@ -47,7 +57,7 @@ class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
def invoke(
-self, input: str | BaseMessage, config: RunnableConfig | None = None
+self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
@@ -115,7 +125,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
""" # noqa: E501
def invoke(
-self, input: str | BaseMessage, config: RunnableConfig | None = None
+self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
@@ -242,8 +252,47 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
return output_parser_dict
-class StrOutputParser(BaseOutputParser[str]):
-"""OutputParser that parses LLMResult into the top likely string.."""
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
else:
yield self.parse_result([Generation(text=chunk)])
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
) -> Iterator[T]:
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
config: Optional[RunnableConfig] = None,
) -> AsyncIterator[T]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
yield chunk
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@property
def lc_serializable(self) -> bool:
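Any `BaseTransformOutputParser` can now consume a chunk iterator directly: string chunks are re-parsed as they arrive and message chunks are unwrapped to their text. A minimal sketch:

```python
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output_parser import StrOutputParser

parser = StrOutputParser()
# Chunks are parsed one at a time instead of after the stream ends.
chunks = iter(["Hello", " ", AIMessageChunk(content="world")])
print(list(parser.transform(chunks)))  # ['Hello', ' ', 'world']
```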

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from itertools import tee
from typing import (
Any,
AsyncIterator,
@@ -29,6 +30,7 @@ from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.utils.aiter import atee, py_anext
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@@ -92,6 +94,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
""" --- Public API --- """
@abstractmethod
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
...
@@ -99,6 +103,10 @@ class Runnable(Generic[Input, Output], ABC):
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
"""
Default implementation of ainvoke, which calls invoke in a thread pool.
Subclasses should override this method if they can run asynchronously.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.invoke, input, config
)
@@ -110,6 +118,10 @@ class Runnable(Generic[Input, Output], ABC):
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
"""
Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
configs = self._get_config_list(config, len(inputs))
# If there's only one input, don't bother with the executor
@@ -126,6 +138,10 @@ class Runnable(Generic[Input, Output], ABC):
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
"""
Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs)
@@ -134,22 +150,90 @@ class Runnable(Generic[Input, Output], ABC):
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
"""
Default implementation of stream, which calls invoke.
Subclasses should override this method if they support streaming output.
"""
yield self.invoke(input, config)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
"""
Default implementation of astream, which calls ainvoke.
Subclasses should override this method if they support streaming output.
"""
yield await self.ainvoke(input, config)
def transform(
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
"""
Default implementation of transform, which buffers input and then calls stream.
Subclasses should override this method if they can start producing output while
input is still being generated.
"""
final: Union[Input, None] = None
for chunk in input:
if final is None:
final = chunk
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
if final:
yield from self.stream(final, config)
async def atransform(
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
"""
Default implementation of atransform, which buffers input and calls astream.
Subclasses should override this method if they can start producing output while
input is still being generated.
"""
final: Union[Input, None] = None
async for chunk in input:
if final is None:
final = chunk
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
if final:
async for output in self.astream(final, config):
yield output
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""
Bind arguments to a Runnable, returning a new Runnable.
"""
return RunnableBinding(bound=self, kwargs=kwargs)
def with_fallbacks(
self,
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
return RunnableWithFallbacks(
runnable=self,
fallbacks=fallbacks,
exceptions_to_handle=exceptions_to_handle,
)
""" --- Helper methods for Subclasses --- """
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
@@ -169,6 +253,8 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
from langchain.callbacks.manager import CallbackManager
config = config or {}
@@ -200,6 +286,8 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
config = config or {}
@@ -224,20 +312,154 @@
)
return output
-def with_fallbacks(
+def _transform_stream_with_config(
self,
-fallbacks: Sequence[Runnable[Input, Output]],
-*,
-exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
-) -> RunnableWithFallbacks[Input, Output]:
-return RunnableWithFallbacks(
-runnable=self,
-fallbacks=fallbacks,
-exceptions_to_handle=exceptions_to_handle,
-)
+input: Iterator[Input],
+transformer: Callable[[Iterator[Input]], Iterator[Output]],
+config: Optional[RunnableConfig],
+run_type: Optional[str] = None,
+) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
from langchain.callbacks.manager import CallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
final_input: Optional[Input] = next(input_for_tracing, None)
final_input_supported = True
final_output: Optional[Output] = None
final_output_supported = True
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
run_type=run_type,
)
try:
for chunk in transformer(input_for_transform):
yield chunk
if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
except TypeError:
final_output = None
final_output_supported = False
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
except Exception as e:
run_manager.on_chain_error(
e,
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
raise
else:
run_manager.on_chain_end(
final_output
if isinstance(final_output, dict)
else {"output": final_output},
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
async def _atransform_stream_with_config(
self,
input: AsyncIterator[Input],
transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
final_input_supported = True
final_output: Optional[Output] = None
final_output_supported = True
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
run_type=run_type,
)
try:
async for chunk in transformer(input_for_transform):
yield chunk
if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
except TypeError:
final_output = None
final_output_supported = False
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
except Exception as e:
await run_manager.on_chain_error(
e,
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
raise
else:
await run_manager.on_chain_end(
final_output
if isinstance(final_output, dict)
else {"output": final_output},
inputs=final_input
if isinstance(final_input, dict)
else {"input": final_input},
)
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
"""
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
@@ -467,6 +689,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class RunnableSequence(Serializable, Runnable[Input, Output]):
"""
A sequence of runnables, where the output of each is the input of the next.
"""
first: Runnable[Input, Any]
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
last: Runnable[Any, Output]
@@ -738,9 +964,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = 0
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).transform != Runnable.transform:
streaming_start_index = i - 1
else:
break
# invoke the first steps
try:
-for step in [self.first] + self.middle:
+for step in steps[0:streaming_start_index]:
input = step.invoke(
input,
# mark each step as a child run
@@ -750,15 +985,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager.on_chain_error(e)
raise
-# stream the last step
+# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
-for output in self.last.stream(
-input,
-# mark the last step as a child run
-_patch_config(config, run_manager.get_child()),
-):
+# stream the first of the last steps with non-streaming input
+final_pipeline = steps[streaming_start_index].stream(
+input, _patch_config(config, run_manager.get_child())
+)
+# stream the rest of the last steps with streaming input
+for step in steps[streaming_start_index + 1 :]:
+final_pipeline = step.transform(
+final_pipeline, _patch_config(config, run_manager.get_child())
+)
+for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
@@ -801,9 +1041,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).transform != Runnable.transform:
streaming_start_index = i - 1
else:
break
# invoke the first steps
try:
-for step in [self.first] + self.middle:
+for step in steps[0:streaming_start_index]:
input = await step.ainvoke(
input,
# mark each step as a child run
@@ -813,15 +1062,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
await run_manager.on_chain_error(e)
raise
-# stream the last step
+# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
-async for output in self.last.astream(
-input,
-# mark the last step as a child run
-_patch_config(config, run_manager.get_child()),
-):
+# stream the first of the last steps with non-streaming input
+final_pipeline = steps[streaming_start_index].astream(
+input, _patch_config(config, run_manager.get_child())
+)
+# stream the rest of the last steps with streaming input
+for step in steps[streaming_start_index + 1 :]:
+final_pipeline = step.atransform(
+final_pipeline, _patch_config(config, run_manager.get_child())
+)
+async for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
@@ -845,6 +1099,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
and returns a mapping of their outputs.
"""
steps: Mapping[str, Runnable[Input, Any]]
def __init__(
@@ -957,6 +1216,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class RunnableLambda(Runnable[Input, Output]):
"""
A runnable that runs a callable.
"""
def __init__(self, func: Callable[[Input], Output]) -> None:
if callable(func):
self.func = func
@@ -977,6 +1240,10 @@ class RunnableLambda(Runnable[Input, Output]):
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
"""
A runnable that passes through the input.
"""
@property
def lc_serializable(self) -> bool:
return True
@@ -986,6 +1253,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnableBinding(Serializable, Runnable[Input, Output]):
"""
A runnable that binds a runnable to a set of kwargs.
"""
bound: Runnable[Input, Output]
kwargs: Mapping[str, Any]
@@ -1041,6 +1312,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
async for item in self.bound.astream(input, config, **self.kwargs):
yield item
def transform(
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
yield from self.bound.transform(input, config, **self.kwargs)
async def atransform(
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(input, config, **self.kwargs):
yield item
class RouterInput(TypedDict):
key: str
@@ -1050,6 +1332,11 @@ class RouterInput(TypedDict):
class RouterRunnable(
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
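The default `transform()` added above is the compatibility path for non-streaming steps: input chunks are gathered with `+`, then handed to `stream()`. A quick sketch with `RunnableLambda`, which inherits that default:

```python
from langchain.schema.runnable import RunnableLambda

upper = RunnableLambda(lambda text: text.upper())
# transform() concatenates "ab" + "cd", then stream() falls back to a single
# invoke(), so the whole input comes back as one output chunk.
assert list(upper.transform(iter(["ab", "cd"]))) == ["ABCD"]
```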

@@ -0,0 +1,191 @@
"""
Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""
from collections import deque
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Deque,
Generic,
Iterator,
List,
Tuple,
TypeVar,
Union,
cast,
overload,
)
T = TypeVar("T")
_no_default = object()
# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
"""Pure-Python implementation of anext() for testing purposes.
Closely matches the builtin anext() C implementation.
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
"""
try:
__anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
)
except AttributeError:
raise TypeError(f"{iterator!r} is not an async iterator")
if default is _no_default:
return __anext__(iterator)
async def anext_impl() -> Union[T, Any]:
try:
# The C code is way more low-level than this, as it implements
# all methods of the iterator protocol. In this implementation
# we're relying on higher-level coroutine concepts, but that's
# exactly what we want -- crosstest pure-Python high-level
# implementation and low-level C anext() iterators.
return await __anext__(iterator)
except StopAsyncIteration:
return default
return anext_impl()
async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
# this peer is done - remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and hasattr(iterator, "aclose"):
await iterator.aclose()
class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.
.. code-block:: python3
async def derivative(sensor_data):
previous, current = a.tee(sensor_data, n=2)
await a.anext(previous) # advance one iterator
return a.map(operator.sub, previous, current)
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
immediately closes all children, and it can be used in an ``async with`` context
for the same effect.
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterator differ
by most data, using a :py:class:`list` is more efficient (but not lazy).
If the underlying iterable is concurrency safe (``anext`` may be awaited
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``anext``, provide a ``lock``
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
and access is automatically synchronised.
"""
def __init__(
self,
iterable: AsyncIterator[T],
n: int = 2,
):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
)
for buffer in self._buffers
)
def __len__(self) -> int:
return len(self._children)
@overload
def __getitem__(self, item: int) -> AsyncIterator[T]:
...
@overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
...
def __getitem__(
self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[AsyncIterator[T]]:
yield from self._children
async def __aenter__(self) -> "Tee[T]":
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
await self.aclose()
return False
async def aclose(self) -> None:
for child in self._children:
await child.aclose()
atee = Tee
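`atee` mirrors `itertools.tee` for async iterators; a small usage sketch:

```python
import asyncio

from langchain.utils.aiter import atee

async def numbers():
    for i in range(3):
        yield i

async def main() -> None:
    # Unpacking works because Tee.__iter__ yields the child iterators.
    first, second = atee(numbers(), 2)
    assert [n async for n in first] == [0, 1, 2]
    # Items were buffered for the slower peer, so it sees the same sequence.
    assert [n async for n in second] == [0, 1, 2]

asyncio.run(main())
```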

@@ -15,7 +15,7 @@ def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
return outputs
-def test_tranform_chain() -> None:
+def test_transform_chain() -> None:
"""Test basic transform chain."""
transform_chain = TransformChain(
input_variables=["first_name", "last_name"],

File diff suppressed because one or more lines are too long

@@ -11,7 +11,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel
-from langchain.llms.fake import FakeListLLM
+from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import (
@@ -22,6 +22,7 @@ from langchain.prompts.chat import (
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
RouterRunnable,
@@ -61,6 +62,8 @@ class FakeTracer(BaseTracer):
if run.parent_run_id
else None,
"child_runs": [self._copy_run(child) for child in run.child_runs],
"execution_order": None,
"child_execution_order": None,
}
)
@@ -302,7 +305,7 @@ async def test_prompt_with_chat_model(
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
-] == [AIMessage(content="foo")]
+] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@@ -678,7 +681,12 @@
"key": "math",
"input": {"question": "2 + 2"},
}
-assert tracer.runs == snapshot
+assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
+parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
+assert len(parent_run.child_runs) == 2
+router_run = parent_run.child_runs[1]
+assert router_run.name == "RunnableSequence"  # TODO: should be RunnableRouter
+assert len(router_run.child_runs) == 2
@freeze_time("2023-01-01")
@@ -758,6 +766,45 @@ def test_bind_bind() -> None:
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
def test_deep_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()
stream = chain.stream({"question": "What up"})
chunks = []
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.mark.asyncio
async def test_deep_astream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()
stream = chain.astream({"question": "What up"})
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
