[core/minor] Runnables: Implement a context api (#14046)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Brace Sproul <braceasproul@gmail.com>
pull/14367/head
Nuno Campos 10 months ago committed by GitHub
parent 8f95a8206b
commit 77c38df36c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,6 +30,7 @@ from langchain_core.runnables.config import (
get_config_list,
patch_config,
)
from langchain_core.runnables.context import Context
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.router import RouterInput, RouterRunnable
@ -47,6 +48,7 @@ __all__ = [
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"Context",
"patch_config",
"RouterInput",
"RouterRunnable",

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait
from copy import deepcopy
from functools import partial, wraps
from itertools import tee
from itertools import groupby, tee
from operator import itemgetter
from typing import (
TYPE_CHECKING,
@ -22,6 +22,7 @@ from typing import (
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@ -1401,9 +1402,46 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps for spec in step.config_specs
from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX, _key_from_id
# get all specs
all_specs = [
(spec, idx)
for idx, step in enumerate(self.steps)
for spec in step.config_specs
]
# calculate context dependencies
specs_by_pos = groupby(
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
lambda x: x[1],
)
next_deps: Set[str] = set()
deps_by_pos: Dict[int, Set[str]] = {}
for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs}
# assign context dependencies
for pos, (spec, idx) in enumerate(all_specs):
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
all_specs[pos] = (
ConfigurableFieldSpec(
id=spec.id,
annotation=spec.annotation,
name=spec.name,
default=spec.default,
description=spec.description,
is_shared=spec.is_shared,
dependencies=[
d
for d in deps_by_pos[idx]
if _key_from_id(d) != _key_from_id(spec.id)
]
+ (spec.dependencies or []),
),
idx,
)
return get_unique_config_specs(spec for spec, _ in all_specs)
def __repr__(self) -> str:
return "\n| ".join(
@ -1456,8 +1494,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = ensure_config(config)
from langchain_core.runnables.context import config_with_context
# setup callbacks and context
config = config_with_context(ensure_config(config), self.steps)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@ -1488,8 +1528,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
# setup callbacks
config = ensure_config(config)
from langchain_core.runnables.context import aconfig_with_context
# setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@ -1523,12 +1565,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any],
) -> List[Output]:
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.runnables.context import config_with_context
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
# setup callbacks and context
configs = [
config_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@ -1641,15 +1687,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
)
from langchain_core.callbacks.manager import AsyncCallbackManager
from langchain_core.runnables.context import aconfig_with_context
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
# setup callbacks and context
configs = [
aconfig_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@ -1763,7 +1811,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[Output]:
from langchain_core.runnables.context import config_with_context
steps = [self.first] + self.middle + [self.last]
config = config_with_context(config, self.steps)
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
@ -1787,7 +1838,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[Output]:
from langchain_core.runnables.context import aconfig_with_context
steps = [self.first] + self.middle + [self.last]
config = aconfig_with_context(config, self.steps)
# stream the last steps
# transform the input stream of each step with the next

@ -26,6 +26,10 @@ from langchain_core.runnables.config import (
get_callback_manager_for_config,
patch_config,
)
from langchain_core.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
@ -148,7 +152,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
specs = get_unique_config_specs(
spec
for step in (
[self.default]
@ -157,6 +161,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
)
for spec in step.config_specs
)
if any(
s.id.startswith(CONTEXT_CONFIG_PREFIX)
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
for s in specs
):
raise ValueError("RunnableBranch cannot contain context setters.")
return specs
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any

@ -0,0 +1,313 @@
import asyncio
import threading
from collections import defaultdict
from functools import partial
from itertools import groupby
from typing import (
Any,
Awaitable,
Callable,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)
from langchain_core.runnables.base import (
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import RunnableConfig, patch_config
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T")
# Shared store mapping a per-key Event to the value set for that key.
Values = Dict[Union[asyncio.Event, threading.Event], Any]
# Config-spec id layout: "__context__/<optional scope>/<key>/get" or ".../set".
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
    """Async context setter: record *value*, wake waiters, and return the value.

    The store MUST happen before ``done.set()`` so any getter awakened by the
    event is guaranteed to find the value in *values*.
    """
    values[done] = value
    done.set()
    return value
async def _agetter(done: asyncio.Event, values: Values) -> Any:
    """Async context getter: block until the paired setter ran, then read its value."""
    await done.wait()
    return values[done]
def _setter(done: threading.Event, values: Values, value: T) -> T:
    """Sync context setter: record *value*, wake waiters, and return the value.

    The store MUST happen before ``done.set()`` so any getter awakened by the
    event is guaranteed to find the value in *values*.
    """
    values[done] = value
    done.set()
    return value
def _getter(done: threading.Event, values: Values) -> Any:
    """Sync context getter: block until the paired setter ran, then read its value."""
    done.wait()
    return values[done]
def _key_from_id(id_: str) -> str:
    """Extract the context key from a config spec id.

    Strips the leading ``CONTEXT_CONFIG_PREFIX`` and the trailing get/set
    suffix; raises ``ValueError`` for ids that carry neither suffix.
    """
    remainder = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
    for suffix in (CONTEXT_CONFIG_SUFFIX_GET, CONTEXT_CONFIG_SUFFIX_SET):
        if remainder.endswith(suffix):
            return remainder[: -len(suffix)]
    raise ValueError(f"Invalid context config id {id_}")
def _config_with_context(
    config: RunnableConfig,
    steps: List[Runnable],
    setter: Callable,
    getter: Callable,
    event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
) -> RunnableConfig:
    """Patch *config* with the callables implementing the Context get/set protocol.

    For every context key used by *steps*, validates the wiring (exactly one
    setter, at least one getter, setter positioned no later than every getter,
    no mutual dependencies between keys) and registers getter/setter callables
    in ``config["configurable"]`` backed by a shared event and value store.

    Args:
        config: The runnable config to patch.
        steps: The steps of the sequence, in execution order.
        setter: Function that stores a value and signals the key's event.
        getter: Function that waits on the key's event and reads the value.
        event_cls: Event class matching the sync/async flavour of
            *setter*/*getter*.

    Returns:
        The config unchanged if it already carries context functions (e.g. a
        nested sequence reusing its parent's context), otherwise a patched copy.

    Raises:
        ValueError: If the context wiring of *steps* is invalid.
    """
    if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
        # Context already set up by an enclosing sequence: reuse it.
        return config
    # Collect all context specs, tagged with the index of the declaring step.
    context_specs = [
        (spec, i)
        for i, step in enumerate(steps)
        for spec in step.config_specs
        if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
    ]
    # Sorting by full id keeps all specs of one key contiguous for groupby.
    grouped_by_key = {
        key: list(group)
        for key, group in groupby(
            sorted(context_specs, key=lambda s: s[0].id),
            key=lambda s: _key_from_id(s[0].id),
        )
    }
    # For each key, the set of other keys its specs depend on.
    deps_by_key = {
        key: set(
            _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
        )
        for key, group in grouped_by_key.items()
    }
    values: Values = {}
    events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
        event_cls
    )
    context_funcs: Dict[str, Callable[[], Any]] = {}
    for key, group in grouped_by_key.items():
        getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
        setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
        for dep in deps_by_key[key]:
            if key in deps_by_key[dep]:
                raise ValueError(
                    f"Deadlock detected between context keys {key} and {dep}"
                )
        if len(getters) < 1:
            raise ValueError(f"Expected at least one getter for context key {key}")
        if len(setters) != 1:
            raise ValueError(f"Expected exactly one setter for context key {key}")
        setter_idx = setters[0][1]
        if any(getter_idx < setter_idx for _, getter_idx in getters):
            # A getter running at an earlier step would wait forever on a value
            # that is only produced later, so the setter must come first.
            # (Message fixed: it previously read "after all getters", which is
            # the opposite of what the check enforces.)
            raise ValueError(
                f"Context setter for key {key} must be defined before all getters."
            )
        # All getters of a key share the same id, so registering the first
        # getter's id covers them all.
        context_funcs[getters[0][0].id] = partial(getter, events[key], values)
        context_funcs[setters[0][0].id] = partial(setter, events[key], values)
    return patch_config(config, configurable=context_funcs)
def aconfig_with_context(
    config: RunnableConfig,
    steps: List[Runnable],
) -> RunnableConfig:
    """Patch *config* with asyncio-flavoured context getters/setters for *steps*."""
    return _config_with_context(
        config, steps, setter=_asetter, getter=_agetter, event_cls=asyncio.Event
    )
def config_with_context(
    config: RunnableConfig,
    steps: List[Runnable],
) -> RunnableConfig:
    """Patch *config* with thread-flavoured context getters/setters for *steps*."""
    return _config_with_context(
        config, steps, setter=_setter, getter=_getter, event_cls=threading.Event
    )
class ContextGet(RunnableSerializable):
    """Runnable that reads values previously written to the context by a setter."""

    # Optional scope prefix (see Context.create_scope); "" means the root scope.
    prefix: str = ""
    # A single context key, or a list of keys to read at once.
    key: Union[str, List[str]]

    @property
    def ids(self) -> List[str]:
        """Config spec ids ("__context__/<prefix><key>/get"), one per key."""
        prefix = self.prefix + "/" if self.prefix else ""
        keys = self.key if isinstance(self.key, list) else [self.key]
        return [
            f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
            for k in keys
        ]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        # Advertise one spec per key so the enclosing sequence can wire the
        # matching getter callables into the config.
        return super().config_specs + [
            ConfigurableFieldSpec(
                id=id_,
                annotation=Callable[[], Any],
            )
            for id_ in self.ids
        ]

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
        """Return the stored value, or a dict of values when ``key`` is a list.

        The input is ignored; values come from the getter callables placed in
        ``config["configurable"]`` by config_with_context.
        """
        config = config or {}
        configurable = config.get("configurable", {})
        if isinstance(self.key, list):
            return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
        else:
            return configurable[self.ids[0]]()

    async def ainvoke(
        self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Any:
        """Async version of invoke; list keys are awaited concurrently via gather."""
        config = config or {}
        configurable = config.get("configurable", {})
        if isinstance(self.key, list):
            values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
            return {key: value for key, value in zip(self.key, values)}
        else:
            return await configurable[self.ids[0]]()
# Anything accepted as a value by Context.setter: a Runnable, a (possibly
# async) single-argument callable, or a plain constant.
SetValue = Union[
    Runnable[Input, Output],
    Callable[[Input], Output],
    Callable[[Input], Awaitable[Output]],
    Any,
]


def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
    """Coerce *value* to a Runnable, wrapping plain constants in an input-ignoring function."""
    if isinstance(value, Runnable) or callable(value):
        return coerce_to_runnable(value)
    # Plain value: store it as-is regardless of the chain input.
    return coerce_to_runnable(lambda _: value)
class ContextSet(RunnableSerializable):
    """Runnable that writes one or more values into the context, passing its input through."""

    # Optional scope prefix; "" means the root scope.
    prefix: str = ""
    # Maps each context key to an optional mapper Runnable applied to the
    # input before storing; None stores the input unchanged.
    keys: Mapping[str, Optional[Runnable]]

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        key: Optional[str] = None,
        value: Optional[SetValue] = None,
        prefix: str = "",
        **kwargs: SetValue,
    ):
        """Accept either a single positional (key, value) pair or key=value kwargs."""
        if key is not None:
            kwargs[key] = value
        super().__init__(
            keys={
                k: _coerce_set_value(v) if v is not None else None
                for k, v in kwargs.items()
            },
            prefix=prefix,
        )

    @property
    def ids(self) -> List[str]:
        """Config spec ids ("__context__/<prefix><key>/set"), one per key."""
        prefix = self.prefix + "/" if self.prefix else ""
        return [
            f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
            for key in self.keys
        ]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        # Specs contributed by the mapper runnables, if any.
        mapper_config_specs = [
            s
            for mapper in self.keys.values()
            if mapper is not None
            for s in mapper.config_specs
        ]
        # Reject mappers that read the very key they are setting.
        for spec in mapper_config_specs:
            if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
                # NOTE(review): split("/")[1] extracts the key only for
                # unprefixed ids; for scoped ids it yields the scope name
                # instead of the key — confirm whether scoped setters need
                # this circular-reference check too.
                getter_key = spec.id.split("/")[1]
                if getter_key in self.keys:
                    raise ValueError(
                        f"Circular reference in context setter for key {getter_key}"
                    )
        return super().config_specs + [
            ConfigurableFieldSpec(
                id=id_,
                annotation=Callable[[], Any],
            )
            for id_ in self.ids
        ]

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
        """Store the (optionally mapped) input under each key; return input unchanged."""
        config = config or {}
        configurable = config.get("configurable", {})
        for id_, mapper in zip(self.ids, self.keys.values()):
            if mapper is not None:
                configurable[id_](mapper.invoke(input, config))
            else:
                configurable[id_](input)
        return input

    async def ainvoke(
        self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Any:
        """Async version of invoke; mappers and setter callables are awaited."""
        config = config or {}
        configurable = config.get("configurable", {})
        for id_, mapper in zip(self.ids, self.keys.values()):
            if mapper is not None:
                await configurable[id_](await mapper.ainvoke(input, config))
            else:
                await configurable[id_](input)
        return input
class Context:
    """Public entry point for reading and writing values scoped to a runnable sequence.

    ``Context.setter(...)`` stores a value as it flows through the chain and
    ``Context.getter(...)`` retrieves it later in the same sequence.
    """

    @staticmethod
    def create_scope(scope: str, /) -> "PrefixContext":
        """Return a getter/setter factory whose keys are namespaced by *scope*."""
        return PrefixContext(prefix=scope)

    @staticmethod
    def getter(key: Union[str, List[str]], /) -> ContextGet:
        """Return a runnable that reads *key* (or each key in a list) from the context."""
        return ContextGet(key=key)

    @staticmethod
    def setter(
        _key: Optional[str] = None,
        _value: Optional[SetValue] = None,
        /,
        **kwargs: SetValue,
    ) -> ContextSet:
        """Return a runnable that writes values into the root-scope context.

        Accepts either a positional key (optionally with a value/mapper) or
        key=value keyword pairs.
        """
        return ContextSet(_key, _value, prefix="", **kwargs)
class PrefixContext:
    """Context getter/setter factory bound to a fixed key prefix (a named scope)."""

    prefix: str = ""

    def __init__(self, prefix: str = ""):
        self.prefix = prefix

    def getter(self, key: Union[str, List[str]], /) -> ContextGet:
        """Return a runnable reading *key* within this scope."""
        return ContextGet(key=key, prefix=self.prefix)

    def setter(
        self,
        _key: Optional[str] = None,
        _value: Optional[SetValue] = None,
        /,
        **kwargs: SetValue,
    ) -> ContextSet:
        """Return a runnable writing values within this scope."""
        return ContextSet(_key, _value, prefix=self.prefix, **kwargs)

@ -308,13 +308,16 @@ class ConfigurableFieldSpec(NamedTuple):
description: Optional[str] = None
default: Any = None
is_shared: bool = False
dependencies: Optional[List[str]] = None
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs."""
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
grouped = groupby(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
)
unique: List[ConfigurableFieldSpec] = []
for id, dupes in grouped:
first = next(dupes)

@ -0,0 +1,411 @@
from typing import Any, Callable, List, NamedTuple, Union
import pytest
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.runnables.context import Context
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import aadd, add
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
class TestCase(NamedTuple):
    """A single (input, expected output) pair for a runnable under test."""

    input: Any
    output: Any
def seq_naive_rag() -> Runnable:
    """Toy RAG chain that stashes the input and retrieved docs in the context."""
    docs = [
        "Hi there!",
        "How are you?",
        "What's your name?",
    ]
    llm = FakeListLLM(responses=["hello"])
    prompt = PromptTemplate.from_template("{context} {question}")
    retriever = RunnableLambda(lambda x: docs)
    return (
        Context.setter("input")
        | {
            "context": retriever | Context.setter("context"),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
        | {
            "result": RunnablePassthrough(),
            "context": Context.getter("context"),
            "input": Context.getter("input"),
        }
    )
def seq_naive_rag_alt() -> Runnable:
    """Variant of seq_naive_rag that also stores the result and reads all keys at once."""
    docs = [
        "Hi there!",
        "How are you?",
        "What's your name?",
    ]
    llm = FakeListLLM(responses=["hello"])
    prompt = PromptTemplate.from_template("{context} {question}")
    retriever = RunnableLambda(lambda x: docs)
    return (
        Context.setter("input")
        | {
            "context": retriever | Context.setter("context"),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
        | Context.setter("result")
        | Context.getter(["context", "input", "result"])
    )
def seq_naive_rag_scoped() -> Runnable:
    """Variant of seq_naive_rag_alt that also exercises a named context scope."""
    docs = [
        "Hi there!",
        "How are you?",
        "What's your name?",
    ]
    llm = FakeListLLM(responses=["hello"])
    prompt = PromptTemplate.from_template("{context} {question}")
    retriever = RunnableLambda(lambda x: docs)
    scoped = Context.create_scope("a_scope")
    return (
        Context.setter("input")
        | {
            "context": retriever | Context.setter("context"),
            "question": RunnablePassthrough(),
            "scoped": scoped.setter("context") | scoped.getter("context"),
        }
        | prompt
        | llm
        | StrOutputParser()
        | Context.setter("result")
        | Context.getter(["context", "input", "result"])
    )
# Each entry pairs a runnable (or a zero-arg factory returning one) with the
# TestCases used to exercise it across invoke/ainvoke/batch/abatch/stream/astream.
test_cases = [
    # Simple set-then-get in a two-step sequence.
    (
        Context.setter("foo") | Context.getter("foo"),
        (
            TestCase("foo", "foo"),
            TestCase("bar", "bar"),
        ),
    ),
    # Getter nested inside a parallel mapping.
    (
        Context.setter("input") | {"bar": Context.getter("input")},
        (
            TestCase("foo", {"bar": "foo"}),
            TestCase("bar", {"bar": "bar"}),
        ),
    ),
    # Setter nested inside a parallel mapping.
    (
        {"bar": Context.setter("input")} | Context.getter("input"),
        (
            TestCase("foo", "foo"),
            TestCase("bar", "bar"),
        ),
    ),
    # Stash the formatted prompt and surface it next to the LLM response.
    (
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt")
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt": Context.getter("prompt"),
            }
        ),
        (
            TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt": StringPromptValue(text="foo bar")},
            ),
            TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt": StringPromptValue(text="bar foo")},
            ),
        ),
    ),
    # Setter with both a positional key and a keyword key + mapper.
    (
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt", prompt_str=lambda x: x.to_string())
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt": Context.getter("prompt"),
                "prompt_str": Context.getter("prompt_str"),
            }
        ),
        (
            TestCase(
                {"foo": "foo", "bar": "bar"},
                {
                    "response": "hello",
                    "prompt": StringPromptValue(text="foo bar"),
                    "prompt_str": "foo bar",
                },
            ),
            TestCase(
                {"foo": "bar", "bar": "foo"},
                {
                    "response": "hello",
                    "prompt": StringPromptValue(text="bar foo"),
                    "prompt_str": "bar foo",
                },
            ),
        ),
    ),
    # Setter with keyword key + mapper only.
    (
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter(prompt_str=lambda x: x.to_string())
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt_str": Context.getter("prompt_str"),
            }
        ),
        (
            TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt_str": "foo bar"},
            ),
            TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt_str": "bar foo"},
            ),
        ),
    ),
    # Setter with positional key and positional mapper.
    (
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt_str", lambda x: x.to_string())
            | FakeListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt_str": Context.getter("prompt_str"),
            }
        ),
        (
            TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt_str": "foo bar"},
            ),
            TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt_str": "bar foo"},
            ),
        ),
    ),
    # Same as the prompt case but with a streaming LLM.
    (
        (
            PromptTemplate.from_template("{foo} {bar}")
            | Context.setter("prompt")
            | FakeStreamingListLLM(responses=["hello"])
            | StrOutputParser()
            | {
                "response": RunnablePassthrough(),
                "prompt": Context.getter("prompt"),
            }
        ),
        (
            TestCase(
                {"foo": "foo", "bar": "bar"},
                {"response": "hello", "prompt": StringPromptValue(text="foo bar")},
            ),
            TestCase(
                {"foo": "bar", "bar": "foo"},
                {"response": "hello", "prompt": StringPromptValue(text="bar foo")},
            ),
        ),
    ),
    # Full toy-RAG chains; passed as factories so fakes are rebuilt per case.
    (
        seq_naive_rag,
        (
            TestCase(
                "What up",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "What up",
                },
            ),
            TestCase(
                "Howdy",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "Howdy",
                },
            ),
        ),
    ),
    (
        seq_naive_rag_alt,
        (
            TestCase(
                "What up",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "What up",
                },
            ),
            TestCase(
                "Howdy",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "Howdy",
                },
            ),
        ),
    ),
    (
        seq_naive_rag_scoped,
        (
            TestCase(
                "What up",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "What up",
                },
            ),
            TestCase(
                "Howdy",
                {
                    "result": "hello",
                    "context": [
                        "Hi there!",
                        "How are you?",
                        "What's your name?",
                    ],
                    "input": "Howdy",
                },
            ),
        ),
    ),
]
@pytest.mark.parametrize("runnable, cases", test_cases)
async def test_context_runnables(
    runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase]
) -> None:
    """Each runnable produces the expected outputs via every invocation API."""
    # Factories are called per test so stateful fakes (e.g. FakeListLLM) are fresh.
    runnable = runnable if isinstance(runnable, Runnable) else runnable()
    assert runnable.invoke(cases[0].input) == cases[0].output
    assert await runnable.ainvoke(cases[1].input) == cases[1].output
    assert runnable.batch([case.input for case in cases]) == [
        case.output for case in cases
    ]
    assert await runnable.abatch([case.input for case in cases]) == [
        case.output for case in cases
    ]
    assert add(runnable.stream(cases[0].input)) == cases[0].output
    assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
def test_runnable_context_seq_key_not_found() -> None:
    """Getting a key that no setter in the sequence provides raises ValueError."""
    seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
    with pytest.raises(ValueError):
        seq.invoke("foo")
def test_runnable_context_seq_key_order() -> None:
    """A getter placed at an earlier step than its setter is rejected."""
    seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
    with pytest.raises(ValueError):
        seq.invoke("foo")
def test_runnable_context_deadlock() -> None:
    """Two keys that each wait on the other's value are rejected as a deadlock."""
    seq: Runnable = {
        "bar": Context.setter("input") | Context.getter("foo"),
        "foo": Context.setter("foo") | Context.getter("input"),
    } | RunnablePassthrough()
    with pytest.raises(ValueError):
        seq.invoke("foo")
def test_runnable_context_seq_key_circular_ref() -> None:
    """A setter whose mapper reads the same key it is setting is rejected."""
    seq: Runnable = {
        "bar": Context.setter(input=Context.getter("input"))
    } | Context.getter("foo")
    with pytest.raises(ValueError):
        seq.invoke("foo")
async def test_runnable_seq_streaming_chunks() -> None:
    """Streaming through a context setter/getter preserves chunking."""
    chain: Runnable = (
        PromptTemplate.from_template("{foo} {bar}")
        | Context.setter("prompt")
        | FakeStreamingListLLM(responses=["hello"])
        | StrOutputParser()
        | {
            "response": RunnablePassthrough(),
            "prompt": Context.getter("prompt"),
        }
    )
    chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})]
    achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
    # Sync and async streams must produce the same chunks (order may differ).
    for c in chunks:
        assert c in achunks
    for c in achunks:
        assert c in chunks
    # 5 response chunks (one per character of "hello") + 1 prompt chunk.
    assert len(chunks) == 6
    assert [c for c in chunks if c.get("response")] == [
        {"response": "h"},
        {"response": "e"},
        {"response": "l"},
        {"response": "l"},
        {"response": "o"},
    ]
    assert [c for c in chunks if c.get("prompt")] == [
        {"prompt": StringPromptValue(text="foo bar")},
    ]

@ -2,6 +2,7 @@ from langchain_core.runnables import __all__
EXPECTED_ALL = [
"AddableDict",
"Context",
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",

Loading…
Cancel
Save