From fb66b392c61d838b3354016b6bb30925823cb994 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 20:12:48 +0100 Subject: [PATCH] Implement RunnablePassthrough.assign(...) (#11222) Passes through dict input and assigns additional keys --- .../langchain/schema/runnable/base.py | 39 +--- .../langchain/schema/runnable/passthrough.py | 200 +++++++++++++++++- .../langchain/schema/runnable/utils.py | 71 ++++++- .../schema/runnable/test_runnable.py | 198 +++++++++++++++++ 4 files changed, 475 insertions(+), 33 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 9b69a9f711..7286333cf4 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -53,6 +53,7 @@ from langchain.schema.runnable.config import ( patch_config, ) from langchain.schema.runnable.utils import ( + AddableDict, Input, Output, accepts_config, @@ -1748,30 +1749,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): yield chunk -class RunnableMapChunk(Dict[str, Any]): - """ - Partial output from a RunnableMap - """ - - def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk: - chunk = RunnableMapChunk(self) - for key in other: - if key not in chunk or chunk[key] is None: - chunk[key] = other[key] - elif other[key] is not None: - chunk[key] = chunk[key] + other[key] - return chunk - - def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk: - chunk = RunnableMapChunk(other) - for key in self: - if key not in chunk or chunk[key] is None: - chunk[key] = self[key] - elif self[key] is not None: - chunk[key] = chunk[key] + self[key] - return chunk - - class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): """ A runnable that runs a mapping of runnables in parallel, @@ -1814,7 +1791,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): @property def input_schema(self) -> type[BaseModel]: - if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()): + if all( + s.input_schema.schema().get("type", "object") == "object" + for s in self.steps.values() + ): # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "RunnableMapInput", @@ -1822,6 +1802,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): k: (v.type_, v.default) for step in self.steps.values() for k, v in step.input_schema.__fields__.items() + if k != "__root__" }, ) @@ -1934,7 +1915,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): input: Iterator[Input], run_manager: CallbackManagerForChainRun, config: RunnableConfig, - ) -> Iterator[RunnableMapChunk]: + ) -> Iterator[AddableDict]: # Shallow copy steps to ignore mutations while in progress steps = dict(self.steps) # Each step gets a copy of the input iterator, @@ -1967,7 +1948,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): for future in completed_futures: (step_name, generator) = futures.pop(future) try: - chunk = RunnableMapChunk({step_name: future.result()}) + chunk = AddableDict({step_name: future.result()}) yield chunk futures[executor.submit(next, generator)] = ( step_name, @@ -1999,7 +1980,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): input: AsyncIterator[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, - ) -> AsyncIterator[RunnableMapChunk]: + ) -> AsyncIterator[AddableDict]: # Shallow copy steps to ignore mutations while in progress steps = dict(self.steps) # Each step gets a copy of the input iterator, @@ -2038,7 +2019,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): for task in completed_tasks: (step_name, generator) = tasks.pop(task) try: - chunk = RunnableMapChunk({step_name: task.result()}) + chunk = AddableDict({step_name: task.result()}) yield chunk new_task = asyncio.create_task(get_next_chunk(generator)) tasks[new_task] = (step_name, generator) diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 5bcead7d95..18afe82591 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -1,10 +1,28 @@ from __future__ import annotations -from typing import Any, AsyncIterator, Iterator, List, Optional, Type +import asyncio +import threading +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Type, + Union, + cast, +) from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Runnable -from langchain.schema.runnable.config import RunnableConfig +from langchain.pydantic_v1 import BaseModel, create_model +from langchain.schema.runnable.base import Input, Runnable, RunnableMap +from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config +from langchain.schema.runnable.utils import AddableDict +from langchain.utils.aiter import atee, py_anext +from langchain.utils.iter import safetee def identity(x: Input) -> Input: @@ -38,6 +56,30 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): def OutputType(self) -> Any: return self.input_type or Any + @classmethod + def assign( + cls, + **kwargs: Union[ + Runnable[Dict[str, Any], Any], + Callable[[Dict[str, Any]], Any], + Mapping[ + str, + Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], + ], + ], + ) -> RunnableAssign: + """ + Merge the Dict input with the output produced by the mapping argument. + + Args: + mapping: A mapping from keys to runnables or callables. + + Returns: + A runnable that merges the Dict input with the output produced by the + mapping argument. + """ + return RunnableAssign(RunnableMap(kwargs)) + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: return self._call_with_config(identity, input, config) @@ -65,3 +107,155 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): ) -> AsyncIterator[Input]: async for chunk in self._atransform_stream_with_config(input, identity, config): yield chunk + + +class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]): + """ + A runnable that assigns key-value pairs to Dict[str, Any] inputs. + """ + + mapper: RunnableMap[Dict[str, Any]] + + def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None: + super().__init__(mapper=mapper, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + @property + def input_schema(self) -> type[BaseModel]: + map_input_schema = self.mapper.input_schema + if not map_input_schema.__custom_root_type__: + # ie. it's a dict + return map_input_schema + + return super().input_schema + + @property + def output_schema(self) -> type[BaseModel]: + map_input_schema = self.mapper.input_schema + map_output_schema = self.mapper.output_schema + if ( + not map_input_schema.__custom_root_type__ + and not map_output_schema.__custom_root_type__ + ): + # ie. both are dicts + return create_model( # type: ignore[call-overload] + "RunnableAssignOutput", + **{ + k: (v.type_, v.default) + for s in (map_input_schema, map_output_schema) + for k, v in s.__fields__.items() + }, + ) + + return super().output_schema + + def invoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + assert isinstance(input, dict) + return { + **input, + **self.mapper.invoke(input, config, **kwargs), + } + + async def ainvoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + assert isinstance(input, dict) + return { + **input, + **await self.mapper.ainvoke(input, config, **kwargs), + } + + def transform( + self, + input: Iterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + # collect mapper keys + mapper_keys = set(self.mapper.steps.keys()) + # create two streams, one for the map and one for the passthrough + for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) + # create map output stream + map_output = self.mapper.transform(for_map, config, **kwargs) + # get executor to start map output stream in background + with get_executor_for_config(config or {}) as executor: + # start map output stream + first_map_chunk_future = executor.submit(next, map_output) # type: ignore + # consume passthrough stream + for chunk in for_passthrough: + assert isinstance(chunk, dict) + # remove mapper keys from passthrough chunk, to be overwritten by map + filtered = AddableDict( + {k: v for k, v in chunk.items() if k not in mapper_keys} + ) + if filtered: + yield filtered + # yield map output + yield cast(Dict[str, Any], first_map_chunk_future.result()) + for chunk in map_output: + yield chunk + + async def atransform( + self, + input: AsyncIterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + # collect mapper keys + mapper_keys = set(self.mapper.steps.keys()) + # create two streams, one for the map and one for the passthrough + for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) + # create map output stream + map_output = self.mapper.atransform(for_map, config, **kwargs) + # start map output stream + first_map_chunk_task: asyncio.Task = asyncio.create_task( + py_anext(map_output), # type: ignore[arg-type] + ) + # consume passthrough stream + async for chunk in for_passthrough: + assert isinstance(chunk, dict) + # remove mapper keys from passthrough chunk, to be overwritten by map output + filtered = AddableDict( + {k: v for k, v in chunk.items() if k not in mapper_keys} + ) + if filtered: + yield filtered + # yield map output + yield await first_map_chunk_task + async for chunk in map_output: + yield chunk + + def stream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + return self.transform(iter([input]), config, **kwargs) + + async def astream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async def input_aiter() -> AsyncIterator[Dict[str, Any]]: + yield input + + async for chunk in self.atransform(input_aiter(), config, **kwargs): + yield chunk diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index b312c8d917..37403f8ea3 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -5,7 +5,20 @@ import asyncio import inspect import textwrap from inspect import signature -from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union +from typing import ( + Any, + AsyncIterable, + Callable, + Coroutine, + Dict, + Iterable, + List, + Optional, + Protocol, + Set, + TypeVar, + Union, +) Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do @@ -142,3 +155,59 @@ def indent_lines_after_first(text: str, prefix: str) -> str: spaces = " " * n_spaces lines = text.splitlines() return "\n".join([lines[0]] + [spaces + line for line in lines[1:]]) + + +class AddableDict(Dict[str, Any]): + """ + Dictionary that can be added to another dictionary. + """ + + def __add__(self, other: AddableDict) -> AddableDict: + chunk = AddableDict(self) + for key in other: + if key not in chunk or chunk[key] is None: + chunk[key] = other[key] + elif other[key] is not None: + chunk[key] = chunk[key] + other[key] + return chunk + + def __radd__(self, other: AddableDict) -> AddableDict: + chunk = AddableDict(other) + for key in self: + if key not in chunk or chunk[key] is None: + chunk[key] = self[key] + elif self[key] is not None: + chunk[key] = chunk[key] + self[key] + return chunk + + +_T_co = TypeVar("_T_co", covariant=True) +_T_contra = TypeVar("_T_contra", contravariant=True) + + +class SupportsAdd(Protocol[_T_contra, _T_co]): + def __add__(self, __x: _T_contra) -> _T_co: + ... + + +Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any]) + + +def add(addables: Iterable[Addable]) -> Optional[Addable]: + final = None + for chunk in addables: + if final is None: + final = chunk + else: + final = final + chunk + return final + + +async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: + final = None + async for chunk in addables: + if final is None: + final = chunk + else: + final = final + chunk + return final diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 29df3726f9..b3484a8d8a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -57,6 +57,7 @@ from langchain.schema.runnable import ( RunnableWithFallbacks, ) from langchain.schema.runnable.base import RunnableGenerator +from langchain.schema.runnable.utils import add from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -2018,6 +2019,104 @@ def test_deep_stream() -> None: assert "".join(chunks) == "foo-lish" +def test_deep_stream_assign() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeStreamingListLLM(responses=["foo-lish"]) + + chain: Runnable = prompt | llm | {"str": StrOutputParser()} + + stream = chain.stream({"question": "What up"}) + + chunks = [] + for chunk in stream: + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") + assert add(chunks) == {"str": "foo-lish"} + + chain_with_assign = chain | RunnablePassthrough.assign( + hello=itemgetter("str") | llm + ) + + assert chain_with_assign.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"question": {"title": "Question"}}, + } + assert chain_with_assign.output_schema.schema() == { + "title": "RunnableAssignOutput", + "type": "object", + "properties": { + "str": {"title": "Str"}, + "hello": {"title": "Hello", "type": "string"}, + }, + } + + chunks = [] + for chunk in chain_with_assign.stream({"question": "What up"}): + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") * 2 + assert chunks == [ + # first stream passthrough input chunks + {"str": "f"}, + {"str": "o"}, + {"str": "o"}, + {"str": "-"}, + {"str": "l"}, + {"str": "i"}, + {"str": "s"}, + {"str": "h"}, + # then stream assign output chunks + {"hello": "f"}, + {"hello": "o"}, + {"hello": "o"}, + {"hello": "-"}, + {"hello": "l"}, + {"hello": "i"}, + {"hello": "s"}, + {"hello": "h"}, + ] + assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"} + assert chain_with_assign.invoke({"question": "What up"}) == { + "str": "foo-lish", + "hello": "foo-lish", + } + + chain_with_assign_shadow = chain | RunnablePassthrough.assign( + str=lambda _: "shadow", + hello=itemgetter("str") | llm, + ) + + assert chain_with_assign_shadow.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"question": {"title": "Question"}}, + } + assert chain_with_assign_shadow.output_schema.schema() == { + "title": "RunnableAssignOutput", + "type": "object", + "properties": { + "str": {"title": "Str"}, + "hello": {"title": "Hello", "type": "string"}, + }, + } + + chunks = [] + for chunk in chain_with_assign_shadow.stream({"question": "What up"}): + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") + 1 + assert add(chunks) == {"str": "shadow", "hello": "foo-lish"} + assert chain_with_assign_shadow.invoke({"question": "What up"}) == { + "str": "shadow", + "hello": "foo-lish", + } + + @pytest.mark.asyncio async def test_deep_astream() -> None: prompt = ( @@ -2045,6 +2144,105 @@ async def test_deep_astream() -> None: assert "".join(chunks) == "foo-lish" +@pytest.mark.asyncio +async def test_deep_astream_assign() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeStreamingListLLM(responses=["foo-lish"]) + + chain: Runnable = prompt | llm | {"str": StrOutputParser()} + + stream = chain.astream({"question": "What up"}) + + chunks = [] + async for chunk in stream: + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") + assert add(chunks) == {"str": "foo-lish"} + + chain_with_assign = chain | RunnablePassthrough.assign( + hello=itemgetter("str") | llm, + ) + + assert chain_with_assign.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"question": {"title": "Question"}}, + } + assert chain_with_assign.output_schema.schema() == { + "title": "RunnableAssignOutput", + "type": "object", + "properties": { + "str": {"title": "Str"}, + "hello": {"title": "Hello", "type": "string"}, + }, + } + + chunks = [] + async for chunk in chain_with_assign.astream({"question": "What up"}): + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") * 2 + assert chunks == [ + # first stream passthrough input chunks + {"str": "f"}, + {"str": "o"}, + {"str": "o"}, + {"str": "-"}, + {"str": "l"}, + {"str": "i"}, + {"str": "s"}, + {"str": "h"}, + # then stream assign output chunks + {"hello": "f"}, + {"hello": "o"}, + {"hello": "o"}, + {"hello": "-"}, + {"hello": "l"}, + {"hello": "i"}, + {"hello": "s"}, + {"hello": "h"}, + ] + assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"} + assert await chain_with_assign.ainvoke({"question": "What up"}) == { + "str": "foo-lish", + "hello": "foo-lish", + } + + chain_with_assign_shadow = chain | RunnablePassthrough.assign( + str=lambda _: "shadow", + hello=itemgetter("str") | llm, + ) + + assert chain_with_assign_shadow.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"question": {"title": "Question"}}, + } + assert chain_with_assign_shadow.output_schema.schema() == { + "title": "RunnableAssignOutput", + "type": "object", + "properties": { + "str": {"title": "Str"}, + "hello": {"title": "Hello", "type": "string"}, + }, + } + + chunks = [] + async for chunk in chain_with_assign_shadow.astream({"question": "What up"}): + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") + 1 + assert add(chunks) == {"str": "shadow", "hello": "foo-lish"} + assert await chain_with_assign_shadow.ainvoke({"question": "What up"}) == { + "str": "shadow", + "hello": "foo-lish", + } + + def test_runnable_sequence_transform() -> None: llm = FakeStreamingListLLM(responses=["foo-lish"])