core: forward config params to default (#20402)

nuno's fault not mine

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
pull/20490/head
Erick Friis 6 months ago committed by GitHub
parent 97b2191e99
commit 7997f3b7f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -4892,6 +4892,23 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
config=self.config, config=self.config,
) )
def __getattr__(self, name: str) -> Any:
attr = getattr(self.bound, name)
if callable(attr) and accepts_config(attr):
@wraps(attr)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return attr(
*args,
**kwargs,
config=merge_configs(self.config, kwargs.get("config")),
)
return wrapper
return attr
RunnableLike = Union[ RunnableLike = Union[
Runnable[Input, Output], Runnable[Input, Output],

@ -3,6 +3,7 @@ from __future__ import annotations
import enum import enum
import threading import threading
from abc import abstractmethod from abc import abstractmethod
from functools import wraps
from typing import ( from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
@ -26,6 +27,7 @@ from langchain_core.runnables.config import (
ensure_config, ensure_config,
get_config_list, get_config_list,
get_executor_for_config, get_executor_for_config,
merge_configs,
) )
from langchain_core.runnables.graph import Graph from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
@ -46,6 +48,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
default: RunnableSerializable[Input, Output] default: RunnableSerializable[Input, Output]
config: Optional[RunnableConfig] = None
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -69,19 +73,37 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> Type[BaseModel]:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return runnable.get_input_schema(config) return runnable.get_input_schema(config)
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> Type[BaseModel]:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return runnable.get_output_schema(config) return runnable.get_output_schema(config)
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return runnable.get_graph(config) return runnable.get_graph(config)
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
**{**self.__dict__, "config": ensure_config(merge_configs(config, kwargs))} # type: ignore[arg-type]
)
def prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
runnable: Runnable[Input, Output] = self
while isinstance(runnable, DynamicRunnable):
runnable, config = runnable._prepare(merge_configs(runnable.config, config))
return runnable, cast(RunnableConfig, config)
@abstractmethod @abstractmethod
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
@ -91,13 +113,13 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return runnable.invoke(input, config, **kwargs) return runnable.invoke(input, config, **kwargs)
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return await runnable.ainvoke(input, config, **kwargs) return await runnable.ainvoke(input, config, **kwargs)
def batch( def batch(
@ -109,7 +131,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs] prepared = [self.prepare(c) for c in configs]
if all(p is self.default for p, _ in prepared): if all(p is self.default for p, _ in prepared):
return self.default.batch( return self.default.batch(
@ -151,7 +173,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs] prepared = [self.prepare(c) for c in configs]
if all(p is self.default for p, _ in prepared): if all(p is self.default for p, _ in prepared):
return await self.default.abatch( return await self.default.abatch(
@ -186,7 +208,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return runnable.stream(input, config, **kwargs) return runnable.stream(input, config, **kwargs)
async def astream( async def astream(
@ -195,7 +217,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
async for chunk in runnable.astream(input, config, **kwargs): async for chunk in runnable.astream(input, config, **kwargs):
yield chunk yield chunk
@ -205,7 +227,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
return runnable.transform(input, config, **kwargs) return runnable.transform(input, config, **kwargs)
async def atransform( async def atransform(
@ -214,10 +236,48 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
runnable, config = self._prepare(config) runnable, config = self.prepare(config)
async for chunk in runnable.atransform(input, config, **kwargs): async for chunk in runnable.atransform(input, config, **kwargs):
yield chunk yield chunk
def __getattr__(self, name: str) -> Any:
attr = getattr(self.default, name)
if callable(attr):
@wraps(attr)
def wrapper(*args: Any, **kwargs: Any) -> Any:
for key, arg in kwargs.items():
if key == "config" and (
isinstance(arg, dict)
and "configurable" in arg
and isinstance(arg["configurable"], dict)
):
runnable, config = self.prepare(cast(RunnableConfig, arg))
kwargs = {**kwargs, "config": config}
return getattr(runnable, name)(*args, **kwargs)
for idx, arg in enumerate(args):
if (
isinstance(arg, dict)
and "configurable" in arg
and isinstance(arg["configurable"], dict)
):
runnable, config = self.prepare(cast(RunnableConfig, arg))
argsl = list(args)
argsl[idx] = config
return getattr(runnable, name)(*argsl, **kwargs)
if self.config:
runnable, config = self.prepare()
return getattr(runnable, name)(*args, **kwargs)
return attr(*args, **kwargs)
return wrapper
else:
return attr
class RunnableConfigurableFields(DynamicRunnable[Input, Output]): class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
"""Runnable that can be dynamically configured. """Runnable that can be dynamically configured.
@ -291,19 +351,21 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs( return get_unique_config_specs(
[ [
ConfigurableFieldSpec( (
id=spec.id, ConfigurableFieldSpec(
name=spec.name, id=spec.id,
description=spec.description name=spec.name,
or self.default.__fields__[field_name].field_info.description, description=spec.description
annotation=spec.annotation or self.default.__fields__[field_name].field_info.description,
or self.default.__fields__[field_name].annotation, annotation=spec.annotation
default=getattr(self.default, field_name), or self.default.__fields__[field_name].annotation,
is_shared=spec.is_shared, default=getattr(self.default, field_name),
) is_shared=spec.is_shared,
if isinstance(spec, ConfigurableField) )
else make_options_spec( if isinstance(spec, ConfigurableField)
spec, self.default.__fields__[field_name].field_info.description else make_options_spec(
spec, self.default.__fields__[field_name].field_info.description
)
) )
for field_name, spec in self.fields.items() for field_name, spec in self.fields.items()
] ]
@ -488,9 +550,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
) )
# config specs of the alternatives # config specs of the alternatives
+ [ + [
prefix_config_spec(s, f"{self.which.id}=={alt_key}") (
if self.prefix_keys prefix_config_spec(s, f"{self.which.id}=={alt_key}")
else s if self.prefix_keys
else s
)
for alt_key, alt in self.alternatives.items() for alt_key, alt in self.alternatives.items()
if isinstance(alt, RunnableSerializable) if isinstance(alt, RunnableSerializable)
for s in alt.config_specs for s in alt.config_specs

@ -1,5 +1,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import pytest
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import ( from langchain_core.runnables import (
ConfigurableField, ConfigurableField,
@ -29,6 +31,25 @@ class MyRunnable(RunnableSerializable[str, str]):
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
return input + self._my_hidden_property return input + self._my_hidden_property
def my_custom_function(self) -> str:
return self.my_property
def my_custom_function_w_config(self, config: RunnableConfig) -> str:
return self.my_property
class MyOtherRunnable(RunnableSerializable[str, str]):
my_other_property: str
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
return input + self.my_other_property
def my_other_custom_function(self) -> str:
return self.my_other_property
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str:
return self.my_other_property
def test_doubly_set_configurable() -> None: def test_doubly_set_configurable() -> None:
"""Test that setting a configurable field with a default value works""" """Test that setting a configurable field with a default value works"""
@ -83,3 +104,86 @@ def test_field_alias_set_configurable() -> None:
) )
== "dc" == "dc"
) )
def test_config_passthrough() -> None:
runnable = MyRunnable(my_property="a") # type: ignore
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property",
name="My property",
description="The property to test",
)
)
# first one
with pytest.raises(AttributeError):
configurable_runnable.not_my_custom_function() # type: ignore[attr-defined]
assert configurable_runnable.my_custom_function() == "a" # type: ignore[attr-defined]
assert (
configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined]
{"configurable": {"my_property": "b"}}
)
== "b"
)
assert (
configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined]
config={"configurable": {"my_property": "b"}}
)
== "b"
)
# second one
assert (
configurable_runnable.with_config(
configurable={"my_property": "b"}
).my_custom_function() # type: ignore[attr-defined]
== "b"
)
def test_config_passthrough_nested() -> None:
runnable = MyRunnable(my_property="a") # type: ignore
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property",
name="My property",
description="The property to test",
)
).configurable_alternatives(
ConfigurableField(id="which", description="Which runnable to use"),
other=MyOtherRunnable(my_other_property="c"),
)
# first one
with pytest.raises(AttributeError):
configurable_runnable.not_my_custom_function() # type: ignore[attr-defined]
assert configurable_runnable.my_custom_function() == "a" # type: ignore[attr-defined]
assert (
configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined]
{"configurable": {"my_property": "b"}}
)
== "b"
)
assert (
configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined]
config={"configurable": {"my_property": "b"}}
)
== "b"
)
assert (
configurable_runnable.with_config(
configurable={"my_property": "b"}
).my_custom_function() # type: ignore[attr-defined]
== "b"
)
# second one
with pytest.raises(AttributeError):
configurable_runnable.my_other_custom_function() # type: ignore[attr-defined]
with pytest.raises(AttributeError):
configurable_runnable.my_other_custom_function_w_config( # type: ignore[attr-defined]
{"configurable": {"my_other_property": "b"}}
)
with pytest.raises(AttributeError):
configurable_runnable.with_config(
configurable={"my_other_property": "c", "which": "other"}
).my_other_custom_function() # type: ignore[attr-defined]

Loading…
Cancel
Save