diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 72975d6acc..8ad86b047d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4021,6 +4021,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): custom_output_type=custom_output_type, **other_kwargs, ) + # if we don't explicitly set config to the TypedDict here, + # the pydantic init above will strip out any of the "extra" + # fields even though total=False on the typed dict. + self.config = config or {} def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index 4710b2d525..5ea4a4c58e 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -1,6 +1,9 @@ +from typing import Any, cast + from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain_core.runnables import RunnableBinding, RunnablePassthrough from langchain_core.runnables.config import RunnableConfig, merge_configs from langchain_core.tracers.stdout import ConsoleCallbackHandler @@ -32,3 +35,11 @@ def test_merge_config_callbacks() -> None: assert len(merged) == 2 assert isinstance(merged[0], ConsoleCallbackHandler) assert isinstance(merged[1], StreamingStdOutCallbackHandler) + + +def test_config_arbitrary_keys() -> None: + base: RunnablePassthrough[Any] = RunnablePassthrough() + bound = base.with_config(my_custom_key="my custom value") + config = cast(RunnableBinding, bound).config + + assert config.get("my_custom_key") == "my custom value"