Remove GetLocal, PutLocal (#12133)

Do you agree?
pull/12168/head^2
Nuno Campos 12 months ago committed by GitHub
parent 8c150ad7f6
commit 34ffb94770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,7 +14,6 @@ creating more responsive UX.
This module contains schema and implementation of LangChain Runnables primitives. This module contains schema and implementation of LangChain Runnables primitives.
""" """
from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import (
Runnable, Runnable,
RunnableBinding, RunnableBinding,
@ -40,9 +39,7 @@ __all__ = [
"ConfigurableField", "ConfigurableField",
"ConfigurableFieldSingleOption", "ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption", "ConfigurableFieldMultiOption",
"GetLocalVar",
"patch_config", "patch_config",
"PutLocalVar",
"RouterInput", "RouterInput",
"RouterRunnable", "RouterRunnable",
"Runnable", "Runnable",

@ -1,168 +0,0 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
Mapping,
Optional,
Union,
)
from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]]
"""The key(s) to use for storing the input variable(s) in local state.
If a string is provided then the entire input is stored under that key. If a
Mapping is provided, then the map values are gotten from the input and
stored in local state under the map keys.
"""
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _concat_put(
self,
input: Other,
*,
config: Optional[RunnableConfig] = None,
replace: bool = False,
) -> None:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
if self.key not in config["locals"] or replace:
config["locals"][self.key] = input
else:
config["locals"][self.key] += input
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
if put_key not in config["locals"] or replace:
config["locals"][put_key] = input[input_key]
else:
config["locals"][put_key] += input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
f"{(type(self.key))}."
)
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Other:
self._concat_put(input, config=config, replace=True)
return super().invoke(input, config=config, **kwargs)
async def ainvoke(
self,
input: Other,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Other:
self._concat_put(input, config=config, replace=True)
return await super().ainvoke(input, config=config, **kwargs)
def transform(
self,
input: Iterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Other]:
for chunk in super().transform(input, config=config, **kwargs):
self._concat_put(chunk, config=config)
yield chunk
async def atransform(
self,
input: AsyncIterator[Other],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Other]:
async for chunk in super().atransform(input, config=config, **kwargs):
self._concat_put(chunk, config=config)
yield chunk
class GetLocalVar(
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
):
key: str
"""The key to extract from the local state."""
passthrough_key: Optional[str] = None
"""The key to use for passing through the invocation input.
If None, then only the value retrieved from local state is returned. Otherwise a
dictionary ``{self.key: <<retrieved_value>>, self.passthrough_key: <<input>>}``
is returned.
"""
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _get(
self,
input: Input,
run_manager: Union[CallbackManagerForChainRun, Any],
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key:
return {
self.key: config["locals"][self.key],
self.passthrough_key: input,
}
else:
return config["locals"][self.key]
async def _aget(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
return self._get(input, run_manager, config)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"GetLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
return self._call_with_config(self._get, input, config)
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"GetLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
return await self._acall_with_config(self._aget, input, config)

@ -1656,7 +1656,6 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
# mark each step as a child run # mark each step as a child run
patch_config( patch_config(
config, config,
copy_locals=True,
callbacks=run_manager.get_child(f"map:key:{key}"), callbacks=run_manager.get_child(f"map:key:{key}"),
), ),
) )
@ -2534,10 +2533,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
[merge_configs(self.config, conf) for conf in config], [merge_configs(self.config, conf) for conf in config],
) )
else: else:
configs = [ configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch( return self.bound.batch(
inputs, inputs,
configs, configs,
@ -2559,10 +2555,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
[merge_configs(self.config, conf) for conf in config], [merge_configs(self.config, conf) for conf in config],
) )
else: else:
configs = [ configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch( return await self.bound.abatch(
inputs, inputs,
configs, configs,

@ -64,13 +64,6 @@ class RunnableConfig(TypedDict, total=False):
Name for the tracer run for this call. Defaults to the name of the class. Name for the tracer run for this call. Defaults to the name of the class.
""" """
locals: Dict[str, Any]
"""
Variables scoped to this call and any sub-calls. Usually used with
GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable
objects in locals, as they will be shared between parallel sub-calls.
"""
max_concurrency: Optional[int] max_concurrency: Optional[int]
""" """
Maximum number of parallel calls to make. If not provided, defaults to Maximum number of parallel calls to make. If not provided, defaults to
@ -96,7 +89,6 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
tags=[], tags=[],
metadata={}, metadata={},
callbacks=None, callbacks=None,
locals={},
recursion_limit=25, recursion_limit=25,
) )
if config is not None: if config is not None:
@ -124,14 +116,13 @@ def get_config_list(
return ( return (
list(map(ensure_config, config)) list(map(ensure_config, config))
if isinstance(config, list) if isinstance(config, list)
else [patch_config(config, copy_locals=True) for _ in range(length)] else [ensure_config(config) for _ in range(length)]
) )
def patch_config( def patch_config(
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
*, *,
copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None, callbacks: Optional[BaseCallbackManager] = None,
recursion_limit: Optional[int] = None, recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
@ -139,8 +130,6 @@ def patch_config(
configurable: Optional[Dict[str, Any]] = None, configurable: Optional[Dict[str, Any]] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = ensure_config(config) config = ensure_config(config)
if copy_locals:
config["locals"] = config["locals"].copy()
if callbacks is not None: if callbacks is not None:
# If we're replacing callbacks we need to unset run_name # If we're replacing callbacks we need to unset run_name
# As that should apply only to the same run as the original callbacks # As that should apply only to the same run as the original callbacks

@ -1,94 +0,0 @@
from typing import Any, Callable, Type
import pytest
from langchain.llms import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
Runnable,
RunnablePassthrough,
RunnableSequence,
)
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.invoke(x), "foo", "foo"),
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
],
)
def test_put_get(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert method(runnable, input) == output
@pytest.mark.asyncio
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.ainvoke(x), "foo", "foo"),
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
],
)
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert await method(runnable, input) == output
@pytest.mark.parametrize(
("runnable", "error"),
[
(PutLocalVar("input"), ValueError),
(GetLocalVar("input"), ValueError),
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
],
)
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
with pytest.raises(error):
runnable.invoke("foo")
def test_get_in_map() -> None:
runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")}
assert runnable.invoke("foo") == {"bar": "foo"}
def test_put_in_map() -> None:
runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError):
runnable.invoke("foo")
@pytest.mark.parametrize(
"runnable",
[
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
(
PutLocalVar("input")
| {"input": RunnablePassthrough()}
| PromptTemplate.from_template("say {input}")
| FakeListLLM(responses=["hello"])
| GetLocalVar("input", passthrough_key="output")
),
],
)
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}),
(lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]),
(
lambda r, x: list(r.stream(x))[0],
"hello",
{"input": "hello", "output": "hello"},
),
],
)
def test_put_get_sequence(
runnable: RunnableSequence, method: Callable, input: Any, output: Any
) -> None:
assert method(runnable, input) == output

@ -1209,7 +1209,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=["c"], tags=["c"],
callbacks=None, callbacks=None,
locals={},
recursion_limit=5, recursion_limit=5,
), ),
), ),
@ -1219,7 +1218,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=["c"], tags=["c"],
callbacks=None, callbacks=None,
locals={},
recursion_limit=5, recursion_limit=5,
), ),
), ),
@ -1290,7 +1288,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=[], tags=[],
callbacks=None, callbacks=None,
locals={},
recursion_limit=25, recursion_limit=25,
), ),
), ),
@ -1300,7 +1297,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=[], tags=[],
callbacks=None, callbacks=None,
locals={},
recursion_limit=25, recursion_limit=25,
), ),
), ),

Loading…
Cancel
Save