Implement RunnablePassthrough.assign(...) (#11222)

Passes through dict input and assigns additional keys

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11235/head
Nuno Campos 10 months ago committed by GitHub
parent 1ddf9f74b2
commit fb66b392c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -53,6 +53,7 @@ from langchain.schema.runnable.config import (
patch_config,
)
from langchain.schema.runnable.utils import (
AddableDict,
Input,
Output,
accepts_config,
@ -1748,30 +1749,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
yield chunk
class RunnableMapChunk(Dict[str, Any]):
"""
Partial output from a RunnableMap
"""
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = RunnableMapChunk(self)
for key in other:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
chunk[key] = chunk[key] + other[key]
return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
chunk = RunnableMapChunk(other)
for key in self:
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
chunk[key] = chunk[key] + self[key]
return chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
@ -1814,7 +1791,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
@property
def input_schema(self) -> type[BaseModel]:
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
if all(
s.input_schema.schema().get("type", "object") == "object"
for s in self.steps.values()
):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
@ -1822,6 +1802,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
k: (v.type_, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
if k != "__root__"
},
)
@ -1934,7 +1915,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
input: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[RunnableMapChunk]:
) -> Iterator[AddableDict]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
# Each step gets a copy of the input iterator,
@ -1967,7 +1948,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
for future in completed_futures:
(step_name, generator) = futures.pop(future)
try:
chunk = RunnableMapChunk({step_name: future.result()})
chunk = AddableDict({step_name: future.result()})
yield chunk
futures[executor.submit(next, generator)] = (
step_name,
@ -1999,7 +1980,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
input: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[RunnableMapChunk]:
) -> AsyncIterator[AddableDict]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
# Each step gets a copy of the input iterator,
@ -2038,7 +2019,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
for task in completed_tasks:
(step_name, generator) = tasks.pop(task)
try:
chunk = RunnableMapChunk({step_name: task.result()})
chunk = AddableDict({step_name: task.result()})
yield chunk
new_task = asyncio.create_task(get_next_chunk(generator))
tasks[new_task] = (step_name, generator)

@ -1,10 +1,28 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Iterator, List, Optional, Type
import asyncio
import threading
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
Union,
cast,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
from langchain.schema.runnable.utils import AddableDict
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
def identity(x: Input) -> Input:
@ -38,6 +56,30 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
def OutputType(self) -> Any:
return self.input_type or Any
@classmethod
def assign(
cls,
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
Mapping[
str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
],
],
) -> RunnableAssign:
"""
Merge the Dict input with the output produced by the mapping argument.
Args:
mapping: A mapping from keys to runnables or callables.
Returns:
A runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableMap(kwargs))
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(identity, input, config)
@ -65,3 +107,155 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
) -> AsyncIterator[Input]:
async for chunk in self._atransform_stream_with_config(input, identity, config):
yield chunk
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
"""
mapper: RunnableMap[Dict[str, Any]]
def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None:
super().__init__(mapper=mapper, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
map_input_schema = self.mapper.input_schema
if not map_input_schema.__custom_root_type__:
# ie. it's a dict
return map_input_schema
return super().input_schema
@property
def output_schema(self) -> type[BaseModel]:
map_input_schema = self.mapper.input_schema
map_output_schema = self.mapper.output_schema
if (
not map_input_schema.__custom_root_type__
and not map_output_schema.__custom_root_type__
):
# ie. both are dicts
return create_model( # type: ignore[call-overload]
"RunnableAssignOutput",
**{
k: (v.type_, v.default)
for s in (map_input_schema, map_output_schema)
for k, v in s.__fields__.items()
},
)
return super().output_schema
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(input, dict)
return {
**input,
**self.mapper.invoke(input, config, **kwargs),
}
async def ainvoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(input, dict)
return {
**input,
**await self.mapper.ainvoke(input, config, **kwargs),
}
def transform(
self,
input: Iterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
# create map output stream
map_output = self.mapper.transform(for_map, config, **kwargs)
# get executor to start map output stream in background
with get_executor_for_config(config or {}) as executor:
# start map output stream
first_map_chunk_future = executor.submit(next, map_output) # type: ignore
# consume passthrough stream
for chunk in for_passthrough:
assert isinstance(chunk, dict)
# remove mapper keys from passthrough chunk, to be overwritten by map
filtered = AddableDict(
{k: v for k, v in chunk.items() if k not in mapper_keys}
)
if filtered:
yield filtered
# yield map output
yield cast(Dict[str, Any], first_map_chunk_future.result())
for chunk in map_output:
yield chunk
async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
# create map output stream
map_output = self.mapper.atransform(for_map, config, **kwargs)
# start map output stream
first_map_chunk_task: asyncio.Task = asyncio.create_task(
py_anext(map_output), # type: ignore[arg-type]
)
# consume passthrough stream
async for chunk in for_passthrough:
assert isinstance(chunk, dict)
# remove mapper keys from passthrough chunk, to be overwritten by map output
filtered = AddableDict(
{k: v for k, v in chunk.items() if k not in mapper_keys}
)
if filtered:
yield filtered
# yield map output
yield await first_map_chunk_task
async for chunk in map_output:
yield chunk
def stream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
return self.transform(iter([input]), config, **kwargs)
async def astream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):
yield chunk

@ -5,7 +5,20 @@ import asyncio
import inspect
import textwrap
from inspect import signature
from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
from typing import (
Any,
AsyncIterable,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Protocol,
Set,
TypeVar,
Union,
)
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
@ -142,3 +155,59 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
spaces = " " * n_spaces
lines = text.splitlines()
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(Dict[str, Any]):
"""
Dictionary that can be added to another dictionary.
"""
def __add__(self, other: AddableDict) -> AddableDict:
chunk = AddableDict(self)
for key in other:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
chunk[key] = chunk[key] + other[key]
return chunk
def __radd__(self, other: AddableDict) -> AddableDict:
chunk = AddableDict(other)
for key in self:
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
chunk[key] = chunk[key] + self[key]
return chunk
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
class SupportsAdd(Protocol[_T_contra, _T_co]):
def __add__(self, __x: _T_contra) -> _T_co:
...
Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
def add(addables: Iterable[Addable]) -> Optional[Addable]:
final = None
for chunk in addables:
if final is None:
final = chunk
else:
final = final + chunk
return final
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
final = None
async for chunk in addables:
if final is None:
final = chunk
else:
final = final + chunk
return final

@ -57,6 +57,7 @@ from langchain.schema.runnable import (
RunnableWithFallbacks,
)
from langchain.schema.runnable.base import RunnableGenerator
from langchain.schema.runnable.utils import add
from langchain.tools.base import BaseTool, tool
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@ -2018,6 +2019,104 @@ def test_deep_stream() -> None:
assert "".join(chunks) == "foo-lish"
def test_deep_stream_assign() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain: Runnable = prompt | llm | {"str": StrOutputParser()}
stream = chain.stream({"question": "What up"})
chunks = []
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"}
chain_with_assign = chain | RunnablePassthrough.assign(
hello=itemgetter("str") | llm
)
assert chain_with_assign.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question"}},
}
assert chain_with_assign.output_schema.schema() == {
"title": "RunnableAssignOutput",
"type": "object",
"properties": {
"str": {"title": "Str"},
"hello": {"title": "Hello", "type": "string"},
},
}
chunks = []
for chunk in chain_with_assign.stream({"question": "What up"}):
chunks.append(chunk)
assert len(chunks) == len("foo-lish") * 2
assert chunks == [
# first stream passthrough input chunks
{"str": "f"},
{"str": "o"},
{"str": "o"},
{"str": "-"},
{"str": "l"},
{"str": "i"},
{"str": "s"},
{"str": "h"},
# then stream assign output chunks
{"hello": "f"},
{"hello": "o"},
{"hello": "o"},
{"hello": "-"},
{"hello": "l"},
{"hello": "i"},
{"hello": "s"},
{"hello": "h"},
]
assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
assert chain_with_assign.invoke({"question": "What up"}) == {
"str": "foo-lish",
"hello": "foo-lish",
}
chain_with_assign_shadow = chain | RunnablePassthrough.assign(
str=lambda _: "shadow",
hello=itemgetter("str") | llm,
)
assert chain_with_assign_shadow.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question"}},
}
assert chain_with_assign_shadow.output_schema.schema() == {
"title": "RunnableAssignOutput",
"type": "object",
"properties": {
"str": {"title": "Str"},
"hello": {"title": "Hello", "type": "string"},
},
}
chunks = []
for chunk in chain_with_assign_shadow.stream({"question": "What up"}):
chunks.append(chunk)
assert len(chunks) == len("foo-lish") + 1
assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
assert chain_with_assign_shadow.invoke({"question": "What up"}) == {
"str": "shadow",
"hello": "foo-lish",
}
@pytest.mark.asyncio
async def test_deep_astream() -> None:
prompt = (
@ -2045,6 +2144,105 @@ async def test_deep_astream() -> None:
assert "".join(chunks) == "foo-lish"
@pytest.mark.asyncio
async def test_deep_astream_assign() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain: Runnable = prompt | llm | {"str": StrOutputParser()}
stream = chain.astream({"question": "What up"})
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"}
chain_with_assign = chain | RunnablePassthrough.assign(
hello=itemgetter("str") | llm,
)
assert chain_with_assign.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question"}},
}
assert chain_with_assign.output_schema.schema() == {
"title": "RunnableAssignOutput",
"type": "object",
"properties": {
"str": {"title": "Str"},
"hello": {"title": "Hello", "type": "string"},
},
}
chunks = []
async for chunk in chain_with_assign.astream({"question": "What up"}):
chunks.append(chunk)
assert len(chunks) == len("foo-lish") * 2
assert chunks == [
# first stream passthrough input chunks
{"str": "f"},
{"str": "o"},
{"str": "o"},
{"str": "-"},
{"str": "l"},
{"str": "i"},
{"str": "s"},
{"str": "h"},
# then stream assign output chunks
{"hello": "f"},
{"hello": "o"},
{"hello": "o"},
{"hello": "-"},
{"hello": "l"},
{"hello": "i"},
{"hello": "s"},
{"hello": "h"},
]
assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
assert await chain_with_assign.ainvoke({"question": "What up"}) == {
"str": "foo-lish",
"hello": "foo-lish",
}
chain_with_assign_shadow = chain | RunnablePassthrough.assign(
str=lambda _: "shadow",
hello=itemgetter("str") | llm,
)
assert chain_with_assign_shadow.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"question": {"title": "Question"}},
}
assert chain_with_assign_shadow.output_schema.schema() == {
"title": "RunnableAssignOutput",
"type": "object",
"properties": {
"str": {"title": "Str"},
"hello": {"title": "Hello", "type": "string"},
},
}
chunks = []
async for chunk in chain_with_assign_shadow.astream({"question": "What up"}):
chunks.append(chunk)
assert len(chunks) == len("foo-lish") + 1
assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
assert await chain_with_assign_shadow.ainvoke({"question": "What up"}) == {
"str": "shadow",
"hello": "foo-lish",
}
def test_runnable_sequence_transform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])

Loading…
Cancel
Save