From bbb609ac9da8d6a6b2b28813a2c592356c8101eb Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 8 Mar 2024 17:35:13 -0800 Subject: [PATCH] core[patch]: fix arbitrary config keys (#18827) --- libs/core/langchain_core/runnables/base.py | 4 ++++ libs/core/tests/unit_tests/runnables/test_config.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 72975d6acc..8ad86b047d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4021,6 +4021,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): custom_output_type=custom_output_type, **other_kwargs, ) + # if we don't explicitly set config to the TypedDict here, + # the pydantic init above will strip out any of the "extra" + # fields even though total=False on the typed dict. + self.config = config or {} def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index 4710b2d525..5ea4a4c58e 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -1,6 +1,9 @@ +from typing import Any, cast + from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain_core.runnables import RunnableBinding, RunnablePassthrough from langchain_core.runnables.config import RunnableConfig, merge_configs from langchain_core.tracers.stdout import ConsoleCallbackHandler @@ -32,3 +35,11 @@ def test_merge_config_callbacks() -> None: assert len(merged) == 2 assert isinstance(merged[0], ConsoleCallbackHandler) assert isinstance(merged[1], StreamingStdOutCallbackHandler) + + +def test_config_arbitrary_keys() -> None: + base: RunnablePassthrough[Any] = RunnablePassthrough() + bound = base.with_config(my_custom_key="my custom value") + config = cast(RunnableBinding, bound).config + + assert config.get("my_custom_key") == "my custom value"