mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
[core/minor] Runnables: Implement a context api (#14046)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Brace Sproul <braceasproul@gmail.com>
This commit is contained in:
parent
8f95a8206b
commit
77c38df36c
@ -30,6 +30,7 @@ from langchain_core.runnables.config import (
|
||||
get_config_list,
|
||||
patch_config,
|
||||
)
|
||||
from langchain_core.runnables.context import Context
|
||||
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.router import RouterInput, RouterRunnable
|
||||
@ -47,6 +48,7 @@ __all__ = [
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"Context",
|
||||
"patch_config",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from concurrent.futures import FIRST_COMPLETED, wait
|
||||
from copy import deepcopy
|
||||
from functools import partial, wraps
|
||||
from itertools import tee
|
||||
from itertools import groupby, tee
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -22,6 +22,7 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -1401,9 +1402,46 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.steps for spec in step.config_specs
|
||||
from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX, _key_from_id
|
||||
|
||||
# get all specs
|
||||
all_specs = [
|
||||
(spec, idx)
|
||||
for idx, step in enumerate(self.steps)
|
||||
for spec in step.config_specs
|
||||
]
|
||||
# calculate context dependencies
|
||||
specs_by_pos = groupby(
|
||||
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
|
||||
lambda x: x[1],
|
||||
)
|
||||
next_deps: Set[str] = set()
|
||||
deps_by_pos: Dict[int, Set[str]] = {}
|
||||
for pos, specs in specs_by_pos:
|
||||
deps_by_pos[pos] = next_deps
|
||||
next_deps = next_deps | {spec[0].id for spec in specs}
|
||||
# assign context dependencies
|
||||
for pos, (spec, idx) in enumerate(all_specs):
|
||||
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
|
||||
all_specs[pos] = (
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
annotation=spec.annotation,
|
||||
name=spec.name,
|
||||
default=spec.default,
|
||||
description=spec.description,
|
||||
is_shared=spec.is_shared,
|
||||
dependencies=[
|
||||
d
|
||||
for d in deps_by_pos[idx]
|
||||
if _key_from_id(d) != _key_from_id(spec.id)
|
||||
]
|
||||
+ (spec.dependencies or []),
|
||||
),
|
||||
idx,
|
||||
)
|
||||
|
||||
return get_unique_config_specs(spec for spec, _ in all_specs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "\n| ".join(
|
||||
@ -1456,8 +1494,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
from langchain_core.runnables.context import config_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
config = config_with_context(ensure_config(config), self.steps)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -1488,8 +1528,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
from langchain_core.runnables.context import aconfig_with_context
|
||||
|
||||
# setup callbacks and context
|
||||
config = aconfig_with_context(ensure_config(config), self.steps)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -1523,12 +1565,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.runnables.context import config_with_context
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
# setup callbacks and context
|
||||
configs = [
|
||||
config_with_context(c, self.steps)
|
||||
for c in get_config_list(config, len(inputs))
|
||||
]
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -1641,15 +1687,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
)
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
from langchain_core.runnables.context import aconfig_with_context
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
# setup callbacks and context
|
||||
configs = [
|
||||
aconfig_with_context(c, self.steps)
|
||||
for c in get_config_list(config, len(inputs))
|
||||
]
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -1763,7 +1811,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Iterator[Output]:
|
||||
from langchain_core.runnables.context import config_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
config = config_with_context(config, self.steps)
|
||||
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
@ -1787,7 +1838,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain_core.runnables.context import aconfig_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
config = aconfig_with_context(config, self.steps)
|
||||
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
|
@ -26,6 +26,10 @@ from langchain_core.runnables.config import (
|
||||
get_callback_manager_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langchain_core.runnables.context import (
|
||||
CONTEXT_CONFIG_PREFIX,
|
||||
CONTEXT_CONFIG_SUFFIX_SET,
|
||||
)
|
||||
from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
Input,
|
||||
@ -148,7 +152,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
specs = get_unique_config_specs(
|
||||
spec
|
||||
for step in (
|
||||
[self.default]
|
||||
@ -157,6 +161,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
for spec in step.config_specs
|
||||
)
|
||||
if any(
|
||||
s.id.startswith(CONTEXT_CONFIG_PREFIX)
|
||||
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
|
||||
for s in specs
|
||||
):
|
||||
raise ValueError("RunnableBranch cannot contain context setters.")
|
||||
return specs
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
|
313
libs/core/langchain_core/runnables/context.py
Normal file
313
libs/core/langchain_core/runnables/context.py
Normal file
@ -0,0 +1,313 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain_core.runnables.config import RunnableConfig, patch_config
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
|
||||
T = TypeVar("T")
|
||||
Values = Dict[Union[asyncio.Event, threading.Event], Any]
|
||||
CONTEXT_CONFIG_PREFIX = "__context__/"
|
||||
CONTEXT_CONFIG_SUFFIX_GET = "/get"
|
||||
CONTEXT_CONFIG_SUFFIX_SET = "/set"
|
||||
|
||||
|
||||
async def _asetter(done: asyncio.Event, values: Values, value: T) -> T:
|
||||
values[done] = value
|
||||
done.set()
|
||||
return value
|
||||
|
||||
|
||||
async def _agetter(done: asyncio.Event, values: Values) -> Any:
|
||||
await done.wait()
|
||||
return values[done]
|
||||
|
||||
|
||||
def _setter(done: threading.Event, values: Values, value: T) -> T:
|
||||
values[done] = value
|
||||
done.set()
|
||||
return value
|
||||
|
||||
|
||||
def _getter(done: threading.Event, values: Values) -> Any:
|
||||
done.wait()
|
||||
return values[done]
|
||||
|
||||
|
||||
def _key_from_id(id_: str) -> str:
|
||||
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
|
||||
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
|
||||
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
|
||||
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
|
||||
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
|
||||
else:
|
||||
raise ValueError(f"Invalid context config id {id_}")
|
||||
|
||||
|
||||
def _config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
setter: Callable,
|
||||
getter: Callable,
|
||||
event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
|
||||
) -> RunnableConfig:
|
||||
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
|
||||
return config
|
||||
|
||||
context_specs = [
|
||||
(spec, i)
|
||||
for i, step in enumerate(steps)
|
||||
for spec in step.config_specs
|
||||
if spec.id.startswith(CONTEXT_CONFIG_PREFIX)
|
||||
]
|
||||
grouped_by_key = {
|
||||
key: list(group)
|
||||
for key, group in groupby(
|
||||
sorted(context_specs, key=lambda s: s[0].id),
|
||||
key=lambda s: _key_from_id(s[0].id),
|
||||
)
|
||||
}
|
||||
deps_by_key = {
|
||||
key: set(
|
||||
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
|
||||
)
|
||||
for key, group in grouped_by_key.items()
|
||||
}
|
||||
|
||||
values: Values = {}
|
||||
events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
|
||||
event_cls
|
||||
)
|
||||
context_funcs: Dict[str, Callable[[], Any]] = {}
|
||||
for key, group in grouped_by_key.items():
|
||||
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
|
||||
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
|
||||
|
||||
for dep in deps_by_key[key]:
|
||||
if key in deps_by_key[dep]:
|
||||
raise ValueError(
|
||||
f"Deadlock detected between context keys {key} and {dep}"
|
||||
)
|
||||
if len(getters) < 1:
|
||||
raise ValueError(f"Expected at least one getter for context key {key}")
|
||||
if len(setters) != 1:
|
||||
raise ValueError(f"Expected exactly one setter for context key {key}")
|
||||
setter_idx = setters[0][1]
|
||||
if any(getter_idx < setter_idx for _, getter_idx in getters):
|
||||
raise ValueError(
|
||||
f"Context setter for key {key} must be defined after all getters."
|
||||
)
|
||||
|
||||
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
|
||||
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
|
||||
|
||||
return patch_config(config, configurable=context_funcs)
|
||||
|
||||
|
||||
def aconfig_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
) -> RunnableConfig:
|
||||
return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
|
||||
|
||||
|
||||
def config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
) -> RunnableConfig:
|
||||
return _config_with_context(config, steps, _setter, _getter, threading.Event)
|
||||
|
||||
|
||||
class ContextGet(RunnableSerializable):
|
||||
prefix: str = ""
|
||||
|
||||
key: Union[str, List[str]]
|
||||
|
||||
@property
|
||||
def ids(self) -> List[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
keys = self.key if isinstance(self.key, list) else [self.key]
|
||||
return [
|
||||
f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
|
||||
for k in keys
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return super().config_specs + [
|
||||
ConfigurableFieldSpec(
|
||||
id=id_,
|
||||
annotation=Callable[[], Any],
|
||||
)
|
||||
for id_ in self.ids
|
||||
]
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
config = config or {}
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
|
||||
else:
|
||||
return configurable[self.ids[0]]()
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
|
||||
return {key: value for key, value in zip(self.key, values)}
|
||||
else:
|
||||
return await configurable[self.ids[0]]()
|
||||
|
||||
|
||||
SetValue = Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input], Awaitable[Output]],
|
||||
Any,
|
||||
]
|
||||
|
||||
|
||||
def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
|
||||
if not isinstance(value, Runnable) and not callable(value):
|
||||
return coerce_to_runnable(lambda _: value)
|
||||
return coerce_to_runnable(value)
|
||||
|
||||
|
||||
class ContextSet(RunnableSerializable):
|
||||
prefix: str = ""
|
||||
|
||||
keys: Mapping[str, Optional[Runnable]]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: Optional[str] = None,
|
||||
value: Optional[SetValue] = None,
|
||||
prefix: str = "",
|
||||
**kwargs: SetValue,
|
||||
):
|
||||
if key is not None:
|
||||
kwargs[key] = value
|
||||
super().__init__(
|
||||
keys={
|
||||
k: _coerce_set_value(v) if v is not None else None
|
||||
for k, v in kwargs.items()
|
||||
},
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def ids(self) -> List[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
return [
|
||||
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
|
||||
for key in self.keys
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
mapper_config_specs = [
|
||||
s
|
||||
for mapper in self.keys.values()
|
||||
if mapper is not None
|
||||
for s in mapper.config_specs
|
||||
]
|
||||
for spec in mapper_config_specs:
|
||||
if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET):
|
||||
getter_key = spec.id.split("/")[1]
|
||||
if getter_key in self.keys:
|
||||
raise ValueError(
|
||||
f"Circular reference in context setter for key {getter_key}"
|
||||
)
|
||||
return super().config_specs + [
|
||||
ConfigurableFieldSpec(
|
||||
id=id_,
|
||||
annotation=Callable[[], Any],
|
||||
)
|
||||
for id_ in self.ids
|
||||
]
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
config = config or {}
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
if mapper is not None:
|
||||
configurable[id_](mapper.invoke(input, config))
|
||||
else:
|
||||
configurable[id_](input)
|
||||
return input
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
if mapper is not None:
|
||||
await configurable[id_](await mapper.ainvoke(input, config))
|
||||
else:
|
||||
await configurable[id_](input)
|
||||
return input
|
||||
|
||||
|
||||
class Context:
|
||||
@staticmethod
|
||||
def create_scope(scope: str, /) -> "PrefixContext":
|
||||
return PrefixContext(prefix=scope)
|
||||
|
||||
@staticmethod
|
||||
def getter(key: Union[str, List[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key)
|
||||
|
||||
@staticmethod
|
||||
def setter(
|
||||
_key: Optional[str] = None,
|
||||
_value: Optional[SetValue] = None,
|
||||
/,
|
||||
**kwargs: SetValue,
|
||||
) -> ContextSet:
|
||||
return ContextSet(_key, _value, prefix="", **kwargs)
|
||||
|
||||
|
||||
class PrefixContext:
|
||||
prefix: str = ""
|
||||
|
||||
def __init__(self, prefix: str = ""):
|
||||
self.prefix = prefix
|
||||
|
||||
def getter(self, key: Union[str, List[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key, prefix=self.prefix)
|
||||
|
||||
def setter(
|
||||
self,
|
||||
_key: Optional[str] = None,
|
||||
_value: Optional[SetValue] = None,
|
||||
/,
|
||||
**kwargs: SetValue,
|
||||
) -> ContextSet:
|
||||
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)
|
@ -308,13 +308,16 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
description: Optional[str] = None
|
||||
default: Any = None
|
||||
is_shared: bool = False
|
||||
dependencies: Optional[List[str]] = None
|
||||
|
||||
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> List[ConfigurableFieldSpec]:
|
||||
"""Get the unique config specs from a sequence of config specs."""
|
||||
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
|
||||
grouped = groupby(
|
||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||
)
|
||||
unique: List[ConfigurableFieldSpec] = []
|
||||
for id, dupes in grouped:
|
||||
first = next(dupes)
|
||||
|
411
libs/core/tests/unit_tests/runnables/test_context.py
Normal file
411
libs/core/tests/unit_tests/runnables/test_context.py
Normal file
@ -0,0 +1,411 @@
|
||||
from typing import Any, Callable, List, NamedTuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompt_values import StringPromptValue
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||
from langchain_core.runnables.context import Context
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import aadd, add
|
||||
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
||||
|
||||
|
||||
class TestCase(NamedTuple):
|
||||
input: Any
|
||||
output: Any
|
||||
|
||||
|
||||
def seq_naive_rag() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"result": RunnablePassthrough(),
|
||||
"context": Context.getter("context"),
|
||||
"input": Context.getter("input"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def seq_naive_rag_alt() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| Context.setter("result")
|
||||
| Context.getter(["context", "input", "result"])
|
||||
)
|
||||
|
||||
|
||||
def seq_naive_rag_scoped() -> Runnable:
|
||||
context = [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
scoped = Context.create_scope("a_scope")
|
||||
|
||||
return (
|
||||
Context.setter("input")
|
||||
| {
|
||||
"context": retriever | Context.setter("context"),
|
||||
"question": RunnablePassthrough(),
|
||||
"scoped": scoped.setter("context") | scoped.getter("context"),
|
||||
}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| Context.setter("result")
|
||||
| Context.getter(["context", "input", "result"])
|
||||
)
|
||||
|
||||
|
||||
test_cases = [
|
||||
(
|
||||
Context.setter("foo") | Context.getter("foo"),
|
||||
(
|
||||
TestCase("foo", "foo"),
|
||||
TestCase("bar", "bar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
Context.setter("input") | {"bar": Context.getter("input")},
|
||||
(
|
||||
TestCase("foo", {"bar": "foo"}),
|
||||
TestCase("bar", {"bar": "bar"}),
|
||||
),
|
||||
),
|
||||
(
|
||||
{"bar": Context.setter("input")} | Context.getter("input"),
|
||||
(
|
||||
TestCase("foo", "foo"),
|
||||
TestCase("bar", "bar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt")
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||
),
|
||||
TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt", prompt_str=lambda x: x.to_string())
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
"prompt_str": Context.getter("prompt_str"),
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{
|
||||
"response": "hello",
|
||||
"prompt": StringPromptValue(text="foo bar"),
|
||||
"prompt_str": "foo bar",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{
|
||||
"response": "hello",
|
||||
"prompt": StringPromptValue(text="bar foo"),
|
||||
"prompt_str": "bar foo",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter(prompt_str=lambda x: x.to_string())
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt_str": Context.getter("prompt_str"),
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt_str": "foo bar"},
|
||||
),
|
||||
TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt_str": "bar foo"},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt_str", lambda x: x.to_string())
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt_str": Context.getter("prompt_str"),
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt_str": "foo bar"},
|
||||
),
|
||||
TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt_str": "bar foo"},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt")
|
||||
| FakeStreamingListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
}
|
||||
),
|
||||
(
|
||||
TestCase(
|
||||
{"foo": "foo", "bar": "bar"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="foo bar")},
|
||||
),
|
||||
TestCase(
|
||||
{"foo": "bar", "bar": "foo"},
|
||||
{"response": "hello", "prompt": StringPromptValue(text="bar foo")},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
seq_naive_rag,
|
||||
(
|
||||
TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "Howdy",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
seq_naive_rag_alt,
|
||||
(
|
||||
TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "Howdy",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
seq_naive_rag_scoped,
|
||||
(
|
||||
TestCase(
|
||||
"What up",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "What up",
|
||||
},
|
||||
),
|
||||
TestCase(
|
||||
"Howdy",
|
||||
{
|
||||
"result": "hello",
|
||||
"context": [
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"What's your name?",
|
||||
],
|
||||
"input": "Howdy",
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runnable, cases", test_cases)
|
||||
async def test_context_runnables(
|
||||
runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase]
|
||||
) -> None:
|
||||
runnable = runnable if isinstance(runnable, Runnable) else runnable()
|
||||
assert runnable.invoke(cases[0].input) == cases[0].output
|
||||
assert await runnable.ainvoke(cases[1].input) == cases[1].output
|
||||
assert runnable.batch([case.input for case in cases]) == [
|
||||
case.output for case in cases
|
||||
]
|
||||
assert await runnable.abatch([case.input for case in cases]) == [
|
||||
case.output for case in cases
|
||||
]
|
||||
assert add(runnable.stream(cases[0].input)) == cases[0].output
|
||||
assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_not_found() -> None:
|
||||
seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_order() -> None:
|
||||
seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_deadlock() -> None:
|
||||
seq: Runnable = {
|
||||
"bar": Context.setter("input") | Context.getter("foo"),
|
||||
"foo": Context.setter("foo") | Context.getter("input"),
|
||||
} | RunnablePassthrough()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
def test_runnable_context_seq_key_circular_ref() -> None:
|
||||
seq: Runnable = {
|
||||
"bar": Context.setter(input=Context.getter("input"))
|
||||
} | Context.getter("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
seq.invoke("foo")
|
||||
|
||||
|
||||
async def test_runnable_seq_streaming_chunks() -> None:
|
||||
chain: Runnable = (
|
||||
PromptTemplate.from_template("{foo} {bar}")
|
||||
| Context.setter("prompt")
|
||||
| FakeStreamingListLLM(responses=["hello"])
|
||||
| StrOutputParser()
|
||||
| {
|
||||
"response": RunnablePassthrough(),
|
||||
"prompt": Context.getter("prompt"),
|
||||
}
|
||||
)
|
||||
|
||||
chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})]
|
||||
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
|
||||
for c in chunks:
|
||||
assert c in achunks
|
||||
for c in achunks:
|
||||
assert c in chunks
|
||||
|
||||
assert len(chunks) == 6
|
||||
assert [c for c in chunks if c.get("response")] == [
|
||||
{"response": "h"},
|
||||
{"response": "e"},
|
||||
{"response": "l"},
|
||||
{"response": "l"},
|
||||
{"response": "o"},
|
||||
]
|
||||
assert [c for c in chunks if c.get("prompt")] == [
|
||||
{"prompt": StringPromptValue(text="foo bar")},
|
||||
]
|
@ -2,6 +2,7 @@ from langchain_core.runnables import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AddableDict",
|
||||
"Context",
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
|
Loading…
Reference in New Issue
Block a user