diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index c6609fc906..eb502b0899 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4892,6 +4892,23 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): config=self.config, ) + def __getattr__(self, name: str) -> Any: + attr = getattr(self.bound, name) + + if callable(attr) and accepts_config(attr): + + @wraps(attr) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return attr( + *args, + **kwargs, + config=merge_configs(self.config, kwargs.get("config")), + ) + + return wrapper + + return attr + RunnableLike = Union[ Runnable[Input, Output], diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 7deee3b54e..410cd976f9 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -3,6 +3,7 @@ from __future__ import annotations import enum import threading from abc import abstractmethod +from functools import wraps from typing import ( Any, AsyncIterator, @@ -26,6 +27,7 @@ from langchain_core.runnables.config import ( ensure_config, get_config_list, get_executor_for_config, + merge_configs, ) from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( @@ -46,6 +48,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): default: RunnableSerializable[Input, Output] + config: Optional[RunnableConfig] = None + class Config: arbitrary_types_allowed = True @@ -69,19 +73,37 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return runnable.get_input_schema(config) def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return runnable.get_output_schema(config) def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return runnable.get_graph(config) + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + return self.__class__( + **{**self.__dict__, "config": ensure_config(merge_configs(config, kwargs))} # type: ignore[arg-type] + ) + + def prepare( + self, config: Optional[RunnableConfig] = None + ) -> Tuple[Runnable[Input, Output], RunnableConfig]: + runnable: Runnable[Input, Output] = self + while isinstance(runnable, DynamicRunnable): + runnable, config = runnable._prepare(merge_configs(runnable.config, config)) + return runnable, cast(RunnableConfig, config) + @abstractmethod def _prepare( self, config: Optional[RunnableConfig] = None @@ -91,13 +113,13 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return runnable.invoke(input, config, **kwargs) async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return await runnable.ainvoke(input, config, **kwargs) def batch( @@ -109,7 +131,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): **kwargs: Optional[Any], ) -> List[Output]: configs = get_config_list(config, len(inputs)) - prepared = [self._prepare(c) for c in configs] + prepared = [self.prepare(c) for c in configs] if all(p is self.default for p, _ in prepared): return self.default.batch( @@ -151,7 +173,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): **kwargs: Optional[Any], ) -> List[Output]: configs = get_config_list(config, len(inputs)) - prepared = [self._prepare(c) for c in configs] + prepared = [self.prepare(c) for c in configs] if all(p is self.default for p, _ in prepared): return await self.default.abatch( @@ -186,7 +208,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return runnable.stream(input, config, **kwargs) async def astream( @@ -195,7 +217,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) async for chunk in runnable.astream(input, config, **kwargs): yield chunk @@ -205,7 +227,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) return runnable.transform(input, config, **kwargs) async def atransform( @@ -214,10 +236,48 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: - runnable, config = self._prepare(config) + runnable, config = self.prepare(config) async for chunk in runnable.atransform(input, config, **kwargs): yield chunk + def __getattr__(self, name: str) -> Any: + attr = getattr(self.default, name) + if callable(attr): + + @wraps(attr) + def wrapper(*args: Any, **kwargs: Any) -> Any: + for key, arg in kwargs.items(): + if key == "config" and ( + isinstance(arg, dict) + and "configurable" in arg + and isinstance(arg["configurable"], dict) + ): + runnable, config = self.prepare(cast(RunnableConfig, arg)) + kwargs = {**kwargs, "config": config} + return getattr(runnable, name)(*args, **kwargs) + + for idx, arg in enumerate(args): + if ( + isinstance(arg, dict) + and "configurable" in arg + and isinstance(arg["configurable"], dict) + ): + runnable, config = self.prepare(cast(RunnableConfig, arg)) + argsl = list(args) + argsl[idx] = config + return getattr(runnable, name)(*argsl, **kwargs) + + if self.config: + runnable, config = self.prepare() + return getattr(runnable, name)(*args, **kwargs) + + return attr(*args, **kwargs) + + return wrapper + + else: + return attr + class RunnableConfigurableFields(DynamicRunnable[Input, Output]): """Runnable that can be dynamically configured. @@ -291,19 +351,21 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): def config_specs(self) -> List[ConfigurableFieldSpec]: return get_unique_config_specs( [ - ConfigurableFieldSpec( - id=spec.id, - name=spec.name, - description=spec.description - or self.default.__fields__[field_name].field_info.description, - annotation=spec.annotation - or self.default.__fields__[field_name].annotation, - default=getattr(self.default, field_name), - is_shared=spec.is_shared, - ) - if isinstance(spec, ConfigurableField) - else make_options_spec( - spec, self.default.__fields__[field_name].field_info.description + ( + ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description + or self.default.__fields__[field_name].field_info.description, + annotation=spec.annotation + or self.default.__fields__[field_name].annotation, + default=getattr(self.default, field_name), + is_shared=spec.is_shared, + ) + if isinstance(spec, ConfigurableField) + else make_options_spec( + spec, self.default.__fields__[field_name].field_info.description + ) ) for field_name, spec in self.fields.items() ] @@ -488,9 +550,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): ) # config specs of the alternatives + [ - prefix_config_spec(s, f"{self.which.id}=={alt_key}") - if self.prefix_keys - else s + ( + prefix_config_spec(s, f"{self.which.id}=={alt_key}") + if self.prefix_keys + else s + ) for alt_key, alt in self.alternatives.items() if isinstance(alt, RunnableSerializable) for s in alt.config_specs diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 0c3801b967..6fa048a3d8 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -1,5 +1,7 @@ from typing import Any, Dict, Optional +import pytest + from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables import ( ConfigurableField, @@ -29,6 +31,25 @@ class MyRunnable(RunnableSerializable[str, str]): def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: return input + self._my_hidden_property + def my_custom_function(self) -> str: + return self.my_property + + def my_custom_function_w_config(self, config: RunnableConfig) -> str: + return self.my_property + + +class MyOtherRunnable(RunnableSerializable[str, str]): + my_other_property: str + + def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: + return input + self.my_other_property + + def my_other_custom_function(self) -> str: + return self.my_other_property + + def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: + return self.my_other_property + def test_doubly_set_configurable() -> None: """Test that setting a configurable field with a default value works""" @@ -83,3 +104,86 @@ def test_field_alias_set_configurable() -> None: ) == "dc" ) + + +def test_config_passthrough() -> None: + runnable = MyRunnable(my_property="a") # type: ignore + configurable_runnable = runnable.configurable_fields( + my_property=ConfigurableField( + id="my_property", + name="My property", + description="The property to test", + ) + ) + # first one + with pytest.raises(AttributeError): + configurable_runnable.not_my_custom_function() # type: ignore[attr-defined] + + assert configurable_runnable.my_custom_function() == "a" # type: ignore[attr-defined] + assert ( + configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined] + {"configurable": {"my_property": "b"}} + ) + == "b" + ) + assert ( + configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "b"}} + ) + == "b" + ) + + # second one + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function() # type: ignore[attr-defined] + == "b" + ) + + +def test_config_passthrough_nested() -> None: + runnable = MyRunnable(my_property="a") # type: ignore + configurable_runnable = runnable.configurable_fields( + my_property=ConfigurableField( + id="my_property", + name="My property", + description="The property to test", + ) + ).configurable_alternatives( + ConfigurableField(id="which", description="Which runnable to use"), + other=MyOtherRunnable(my_other_property="c"), + ) + # first one + with pytest.raises(AttributeError): + configurable_runnable.not_my_custom_function() # type: ignore[attr-defined] + assert configurable_runnable.my_custom_function() == "a" # type: ignore[attr-defined] + assert ( + configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined] + {"configurable": {"my_property": "b"}} + ) + == "b" + ) + assert ( + configurable_runnable.my_custom_function_w_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "b"}} + ) + == "b" + ) + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function() # type: ignore[attr-defined] + == "b" + ) + # second one + with pytest.raises(AttributeError): + configurable_runnable.my_other_custom_function() # type: ignore[attr-defined] + with pytest.raises(AttributeError): + configurable_runnable.my_other_custom_function_w_config( # type: ignore[attr-defined] + {"configurable": {"my_other_property": "b"}} + ) + with pytest.raises(AttributeError): + configurable_runnable.with_config( + configurable={"my_other_property": "c", "which": "other"} + ).my_other_custom_function() # type: ignore[attr-defined]