Implement RunnablePassthrough.assign(...) (#11222)

Passes the dict input through unchanged and assigns additional keys to it, computed from that input by the given runnables or callables.
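A minimal usage sketch of the new API (not taken from the diff; the key names and lambda are illustrative):

from langchain.schema.runnable import RunnablePassthrough

# assign() keeps every key of the incoming dict and adds the computed ones
chain = RunnablePassthrough.assign(total=lambda d: d["a"] + d["b"])
assert chain.invoke({"a": 1, "b": 2}) == {"a": 1, "b": 2, "total": 3}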

Nuno Campos committed 11 months ago (via GitHub)
commit fb66b392c6 · parent 1ddf9f74b2 · pull/11235/head

@@ -53,6 +53,7 @@ from langchain.schema.runnable.config import (
     patch_config,
 )
 from langchain.schema.runnable.utils import (
+    AddableDict,
     Input,
     Output,
     accepts_config,
@@ -1748,30 +1749,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                 yield chunk


-class RunnableMapChunk(Dict[str, Any]):
-    """
-    Partial output from a RunnableMap
-    """
-
-    def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
-        chunk = RunnableMapChunk(self)
-        for key in other:
-            if key not in chunk or chunk[key] is None:
-                chunk[key] = other[key]
-            elif other[key] is not None:
-                chunk[key] = chunk[key] + other[key]
-        return chunk
-
-    def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
-        chunk = RunnableMapChunk(other)
-        for key in self:
-            if key not in chunk or chunk[key] is None:
-                chunk[key] = self[key]
-            elif self[key] is not None:
-                chunk[key] = chunk[key] + self[key]
-        return chunk
-
-
 class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
     """
     A runnable that runs a mapping of runnables in parallel,
@@ -1814,7 +1791,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):

     @property
     def input_schema(self) -> type[BaseModel]:
-        if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
+        if all(
+            s.input_schema.schema().get("type", "object") == "object"
+            for s in self.steps.values()
+        ):
             # This is correct, but pydantic typings/mypy don't think so.
             return create_model(  # type: ignore[call-overload]
                 "RunnableMapInput",
@@ -1822,6 +1802,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
                     k: (v.type_, v.default)
                     for step in self.steps.values()
                     for k, v in step.input_schema.__fields__.items()
+                    if k != "__root__"
                 },
             )

@@ -1934,7 +1915,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         input: Iterator[Input],
         run_manager: CallbackManagerForChainRun,
         config: RunnableConfig,
-    ) -> Iterator[RunnableMapChunk]:
+    ) -> Iterator[AddableDict]:
         # Shallow copy steps to ignore mutations while in progress
         steps = dict(self.steps)
         # Each step gets a copy of the input iterator,
@@ -1967,7 +1948,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
                 for future in completed_futures:
                     (step_name, generator) = futures.pop(future)
                     try:
-                        chunk = RunnableMapChunk({step_name: future.result()})
+                        chunk = AddableDict({step_name: future.result()})
                         yield chunk
                         futures[executor.submit(next, generator)] = (
                             step_name,
@@ -1999,7 +1980,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         input: AsyncIterator[Input],
         run_manager: AsyncCallbackManagerForChainRun,
         config: RunnableConfig,
-    ) -> AsyncIterator[RunnableMapChunk]:
+    ) -> AsyncIterator[AddableDict]:
         # Shallow copy steps to ignore mutations while in progress
         steps = dict(self.steps)
         # Each step gets a copy of the input iterator,
@@ -2038,7 +2019,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
             for task in completed_tasks:
                 (step_name, generator) = tasks.pop(task)
                 try:
-                    chunk = RunnableMapChunk({step_name: task.result()})
+                    chunk = AddableDict({step_name: task.result()})
                     yield chunk
                     new_task = asyncio.create_task(get_next_chunk(generator))
                     tasks[new_task] = (step_name, generator)
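For context on the rename above: RunnableMap's stream now yields AddableDict chunks (one key per step), which can be folded back into a single dict. A small sketch, not part of this diff, using RunnableLambda stand-ins and the add() helper introduced further down:

from langchain.schema.runnable import RunnableLambda, RunnableMap
from langchain.schema.runnable.utils import AddableDict, add

# each streamed chunk is an AddableDict holding the output of one step
map_ = RunnableMap(
    {"orig": RunnableLambda(lambda x: x), "double": RunnableLambda(lambda x: x * 2)}
)
chunks = list(map_.stream(3))
assert all(isinstance(c, AddableDict) for c in chunks)
assert add(chunks) == {"orig": 3, "double": 6}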

@@ -1,10 +1,28 @@
 from __future__ import annotations

-from typing import Any, AsyncIterator, Iterator, List, Optional, Type
+import asyncio
+import threading
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Type,
+    Union,
+    cast,
+)

 from langchain.load.serializable import Serializable
-from langchain.schema.runnable.base import Input, Runnable
-from langchain.schema.runnable.config import RunnableConfig
+from langchain.pydantic_v1 import BaseModel, create_model
+from langchain.schema.runnable.base import Input, Runnable, RunnableMap
+from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
+from langchain.schema.runnable.utils import AddableDict
+from langchain.utils.aiter import atee, py_anext
+from langchain.utils.iter import safetee


 def identity(x: Input) -> Input:
@@ -38,6 +56,30 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
     def OutputType(self) -> Any:
         return self.input_type or Any

+    @classmethod
+    def assign(
+        cls,
+        **kwargs: Union[
+            Runnable[Dict[str, Any], Any],
+            Callable[[Dict[str, Any]], Any],
+            Mapping[
+                str,
+                Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
+            ],
+        ],
+    ) -> RunnableAssign:
+        """
+        Merge the Dict input with the output produced by the mapping argument.
+
+        Args:
+            mapping: A mapping from keys to runnables or callables.
+
+        Returns:
+            A runnable that merges the Dict input with the output produced by the
+            mapping argument.
+        """
+        return RunnableAssign(RunnableMap(kwargs))
+
     def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
         return self._call_with_config(identity, input, config)
@@ -65,3 +107,155 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
     ) -> AsyncIterator[Input]:
         async for chunk in self._atransform_stream_with_config(input, identity, config):
             yield chunk
+
+
+class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
+    """
+    A runnable that assigns key-value pairs to Dict[str, Any] inputs.
+    """
+
+    mapper: RunnableMap[Dict[str, Any]]
+
+    def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None:
+        super().__init__(mapper=mapper, **kwargs)
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+    @classmethod
+    def get_lc_namespace(cls) -> List[str]:
+        return cls.__module__.split(".")[:-1]
+
+    @property
+    def input_schema(self) -> type[BaseModel]:
+        map_input_schema = self.mapper.input_schema
+        if not map_input_schema.__custom_root_type__:
+            # ie. it's a dict
+            return map_input_schema
+
+        return super().input_schema
+
+    @property
+    def output_schema(self) -> type[BaseModel]:
+        map_input_schema = self.mapper.input_schema
+        map_output_schema = self.mapper.output_schema
+        if (
+            not map_input_schema.__custom_root_type__
+            and not map_output_schema.__custom_root_type__
+        ):
+            # ie. both are dicts
+            return create_model(  # type: ignore[call-overload]
+                "RunnableAssignOutput",
+                **{
+                    k: (v.type_, v.default)
+                    for s in (map_input_schema, map_output_schema)
+                    for k, v in s.__fields__.items()
+                },
+            )
+
+        return super().output_schema
+
+    def invoke(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        assert isinstance(input, dict)
+        return {
+            **input,
+            **self.mapper.invoke(input, config, **kwargs),
+        }
+
+    async def ainvoke(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        assert isinstance(input, dict)
+        return {
+            **input,
+            **await self.mapper.ainvoke(input, config, **kwargs),
+        }
+
+    def transform(
+        self,
+        input: Iterator[Dict[str, Any]],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        # collect mapper keys
+        mapper_keys = set(self.mapper.steps.keys())
+        # create two streams, one for the map and one for the passthrough
+        for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
+        # create map output stream
+        map_output = self.mapper.transform(for_map, config, **kwargs)
+        # get executor to start map output stream in background
+        with get_executor_for_config(config or {}) as executor:
+            # start map output stream
+            first_map_chunk_future = executor.submit(next, map_output)  # type: ignore
+            # consume passthrough stream
+            for chunk in for_passthrough:
+                assert isinstance(chunk, dict)
+                # remove mapper keys from passthrough chunk, to be overwritten by map
+                filtered = AddableDict(
+                    {k: v for k, v in chunk.items() if k not in mapper_keys}
+                )
+                if filtered:
+                    yield filtered
+            # yield map output
+            yield cast(Dict[str, Any], first_map_chunk_future.result())
+            for chunk in map_output:
+                yield chunk
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Dict[str, Any]],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        # collect mapper keys
+        mapper_keys = set(self.mapper.steps.keys())
+        # create two streams, one for the map and one for the passthrough
+        for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
+        # create map output stream
+        map_output = self.mapper.atransform(for_map, config, **kwargs)
+        # start map output stream
+        first_map_chunk_task: asyncio.Task = asyncio.create_task(
+            py_anext(map_output),  # type: ignore[arg-type]
+        )
+        # consume passthrough stream
+        async for chunk in for_passthrough:
+            assert isinstance(chunk, dict)
+            # remove mapper keys from passthrough chunk, to be overwritten by map output
+            filtered = AddableDict(
+                {k: v for k, v in chunk.items() if k not in mapper_keys}
+            )
+            if filtered:
+                yield filtered
+        # yield map output
+        yield await first_map_chunk_task
+        async for chunk in map_output:
+            yield chunk
+
+    def stream(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        return self.transform(iter([input]), config, **kwargs)
+
+    async def astream(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
+            yield input
+
+        async for chunk in self.atransform(input_aiter(), config, **kwargs):
+            yield chunk
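A self-contained sketch of how RunnableAssign behaves when streaming (RunnableLambda stand-ins instead of a real prompt/LLM chain; not part of the diff): passthrough chunks are yielded first with the assigned keys filtered out, then the mapper's chunks follow.

from operator import itemgetter

from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

# upstream step emits a dict; assign() derives an extra key from it
chain = RunnableLambda(lambda q: {"str": q.upper()}) | RunnablePassthrough.assign(
    shout=itemgetter("str") | RunnableLambda(lambda s: s + "!!!")
)
assert chain.invoke("hi") == {"str": "HI", "shout": "HI!!!"}
# streaming: passthrough keys come first, then the assigned keys
assert list(chain.stream("hi")) == [{"str": "HI"}, {"shout": "HI!!!"}]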

@@ -5,7 +5,20 @@ import asyncio
 import inspect
 import textwrap
 from inspect import signature
-from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
+from typing import (
+    Any,
+    AsyncIterable,
+    Callable,
+    Coroutine,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    TypeVar,
+    Union,
+)

 Input = TypeVar("Input")
 # Output type should implement __concat__, as eg str, list, dict do
@@ -142,3 +155,59 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
     spaces = " " * n_spaces
     lines = text.splitlines()
     return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
+
+
+class AddableDict(Dict[str, Any]):
+    """
+    Dictionary that can be added to another dictionary.
+    """
+
+    def __add__(self, other: AddableDict) -> AddableDict:
+        chunk = AddableDict(self)
+        for key in other:
+            if key not in chunk or chunk[key] is None:
+                chunk[key] = other[key]
+            elif other[key] is not None:
+                chunk[key] = chunk[key] + other[key]
+        return chunk
+
+    def __radd__(self, other: AddableDict) -> AddableDict:
+        chunk = AddableDict(other)
+        for key in self:
+            if key not in chunk or chunk[key] is None:
+                chunk[key] = self[key]
+            elif self[key] is not None:
+                chunk[key] = chunk[key] + self[key]
+        return chunk
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T_contra = TypeVar("_T_contra", contravariant=True)
+
+
+class SupportsAdd(Protocol[_T_contra, _T_co]):
+    def __add__(self, __x: _T_contra) -> _T_co:
+        ...
+
+
+Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
+
+
+def add(addables: Iterable[Addable]) -> Optional[Addable]:
+    final = None
+    for chunk in addables:
+        if final is None:
+            final = chunk
+        else:
+            final = final + chunk
+    return final
+
+
+async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
+    final = None
+    async for chunk in addables:
+        if final is None:
+            final = chunk
+        else:
+            final = final + chunk
+    return final
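A quick illustration of the AddableDict semantics and the new add() helper (the values are made up):

from langchain.schema.runnable.utils import AddableDict, add

a = AddableDict({"text": "foo", "done": None})
b = AddableDict({"text": "-lish", "done": True})
# overlapping non-None values are concatenated; None is overwritten
assert a + b == {"text": "foo-lish", "done": True}
# add() folds any iterable of addable chunks into a single value
assert add([a, b]) == {"text": "foo-lish", "done": True}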

@@ -57,6 +57,7 @@ from langchain.schema.runnable import (
     RunnableWithFallbacks,
 )
 from langchain.schema.runnable.base import RunnableGenerator
+from langchain.schema.runnable.utils import add
 from langchain.tools.base import BaseTool, tool
 from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@@ -2018,6 +2019,104 @@ def test_deep_stream() -> None:
     assert "".join(chunks) == "foo-lish"


+def test_deep_stream_assign() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain: Runnable = prompt | llm | {"str": StrOutputParser()}
+
+    stream = chain.stream({"question": "What up"})
+
+    chunks = []
+    for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert add(chunks) == {"str": "foo-lish"}
+
+    chain_with_assign = chain | RunnablePassthrough.assign(
+        hello=itemgetter("str") | llm
+    )
+
+    assert chain_with_assign.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"question": {"title": "Question"}},
+    }
+    assert chain_with_assign.output_schema.schema() == {
+        "title": "RunnableAssignOutput",
+        "type": "object",
+        "properties": {
+            "str": {"title": "Str"},
+            "hello": {"title": "Hello", "type": "string"},
+        },
+    }
+
+    chunks = []
+    for chunk in chain_with_assign.stream({"question": "What up"}):
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish") * 2
+    assert chunks == [
+        # first stream passthrough input chunks
+        {"str": "f"},
+        {"str": "o"},
+        {"str": "o"},
+        {"str": "-"},
+        {"str": "l"},
+        {"str": "i"},
+        {"str": "s"},
+        {"str": "h"},
+        # then stream assign output chunks
+        {"hello": "f"},
+        {"hello": "o"},
+        {"hello": "o"},
+        {"hello": "-"},
+        {"hello": "l"},
+        {"hello": "i"},
+        {"hello": "s"},
+        {"hello": "h"},
+    ]
+    assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
+    assert chain_with_assign.invoke({"question": "What up"}) == {
+        "str": "foo-lish",
+        "hello": "foo-lish",
+    }
+
+    chain_with_assign_shadow = chain | RunnablePassthrough.assign(
+        str=lambda _: "shadow",
+        hello=itemgetter("str") | llm,
+    )
+
+    assert chain_with_assign_shadow.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"question": {"title": "Question"}},
+    }
+    assert chain_with_assign_shadow.output_schema.schema() == {
+        "title": "RunnableAssignOutput",
+        "type": "object",
+        "properties": {
+            "str": {"title": "Str"},
+            "hello": {"title": "Hello", "type": "string"},
+        },
+    }
+
+    chunks = []
+    for chunk in chain_with_assign_shadow.stream({"question": "What up"}):
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish") + 1
+    assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
+    assert chain_with_assign_shadow.invoke({"question": "What up"}) == {
+        "str": "shadow",
+        "hello": "foo-lish",
+    }
+
+
 @pytest.mark.asyncio
 async def test_deep_astream() -> None:
     prompt = (
@@ -2045,6 +2144,105 @@ async def test_deep_astream() -> None:
     assert "".join(chunks) == "foo-lish"


+@pytest.mark.asyncio
+async def test_deep_astream_assign() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain: Runnable = prompt | llm | {"str": StrOutputParser()}
+
+    stream = chain.astream({"question": "What up"})
+
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert add(chunks) == {"str": "foo-lish"}
+
+    chain_with_assign = chain | RunnablePassthrough.assign(
+        hello=itemgetter("str") | llm,
+    )
+
+    assert chain_with_assign.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"question": {"title": "Question"}},
+    }
+    assert chain_with_assign.output_schema.schema() == {
+        "title": "RunnableAssignOutput",
+        "type": "object",
+        "properties": {
+            "str": {"title": "Str"},
+            "hello": {"title": "Hello", "type": "string"},
+        },
+    }
+
+    chunks = []
+    async for chunk in chain_with_assign.astream({"question": "What up"}):
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish") * 2
+    assert chunks == [
+        # first stream passthrough input chunks
+        {"str": "f"},
+        {"str": "o"},
+        {"str": "o"},
+        {"str": "-"},
+        {"str": "l"},
+        {"str": "i"},
+        {"str": "s"},
+        {"str": "h"},
+        # then stream assign output chunks
+        {"hello": "f"},
+        {"hello": "o"},
+        {"hello": "o"},
+        {"hello": "-"},
+        {"hello": "l"},
+        {"hello": "i"},
+        {"hello": "s"},
+        {"hello": "h"},
+    ]
+    assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
+    assert await chain_with_assign.ainvoke({"question": "What up"}) == {
+        "str": "foo-lish",
+        "hello": "foo-lish",
+    }
+
+    chain_with_assign_shadow = chain | RunnablePassthrough.assign(
+        str=lambda _: "shadow",
+        hello=itemgetter("str") | llm,
+    )
+
+    assert chain_with_assign_shadow.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"question": {"title": "Question"}},
+    }
+    assert chain_with_assign_shadow.output_schema.schema() == {
+        "title": "RunnableAssignOutput",
+        "type": "object",
+        "properties": {
+            "str": {"title": "Str"},
+            "hello": {"title": "Hello", "type": "string"},
+        },
+    }
+
+    chunks = []
+    async for chunk in chain_with_assign_shadow.astream({"question": "What up"}):
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish") + 1
+    assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
+    assert await chain_with_assign_shadow.ainvoke({"question": "What up"}) == {
+        "str": "shadow",
+        "hello": "foo-lish",
+    }
+
+
 def test_runnable_sequence_transform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])
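And a small async counterpart (again a sketch with RunnableLambda stand-ins, not taken from the diff), folding the streamed chunks back together with aadd():

import asyncio

from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.schema.runnable.utils import aadd


async def main() -> None:
    chain = RunnableLambda(lambda q: {"str": q.upper()}) | RunnablePassthrough.assign(
        shout=lambda d: d["str"] + "!!!"
    )
    # aadd() folds the async chunk stream back into a single dict
    assert await aadd(chain.astream("hi")) == {"str": "HI", "shout": "HI!!!"}


asyncio.run(main())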
