diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 4650c2dab6..d969e4ab6e 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2035,9 +2035,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): copy = cast(RunnableConfig, dict(self.config)) if config: for key in config: - # Even though the keys aren't literals this is correct - # because both dicts are same type - copy[key] = config[key] or copy.get(key) # type: ignore + if key == "metadata": + copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore + elif key == "tags": + copy[key] = (copy.get(key) or []) + config[key] # type: ignore + else: + # Even though the keys aren't literals this is correct + # because both dicts are same type + copy[key] = config[key] or copy.get(key) # type: ignore return copy def bind(self, **kwargs: Any) -> Runnable[Input, Output]: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 61f74eb22b..2a4bd2074e 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -7,7 +7,7 @@ from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion -from langchain.callbacks.manager import Callbacks +from langchain.callbacks.manager import Callbacks, collect_runs from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler @@ -1271,6 +1271,31 @@ def test_with_config_with_config() -> None: ) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]})) +def test_metadata_is_merged() -> None: + """Test metadata and tags defined in with_config and at are merged/concatend.""" + + foo = RunnableLambda(lambda x: x).with_config({"metadata": {"my_key": "my_value"}}) + expected_metadata = { + "my_key": "my_value", + "my_other_key": "my_other_value", + } + with collect_runs() as cb: + foo.invoke("hi", {"metadata": {"my_other_key": "my_other_value"}}) + run = cb.traced_runs[0] + assert run.extra["metadata"] == expected_metadata + + +def test_tags_are_appended() -> None: + """Test tags from with_config are concatenated with those in invocation.""" + + foo = RunnableLambda(lambda x: x).with_config({"tags": ["my_key"]}) + with collect_runs() as cb: + foo.invoke("hi", {"tags": ["invoked_key"]}) + run = cb.traced_runs[0] + assert isinstance(run.tags, list) + assert sorted(run.tags) == sorted(["my_key", "invoked_key"]) + + def test_bind_bind() -> None: llm = FakeListLLM(responses=["i'm a textbot"])