Merge metadata + tags in config (#10762)

Think these should be a merge/update rather than overwrite
pull/10785/head
William FH 11 months ago committed by GitHub
parent 71025013f8
commit c8f386db97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2035,9 +2035,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
if key == "metadata":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:

@ -7,7 +7,7 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.manager import Callbacks, collect_runs
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
@ -1271,6 +1271,31 @@ def test_with_config_with_config() -> None:
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
def test_metadata_is_merged() -> None:
"""Test metadata and tags defined in with_config and at are merged/concatend."""
foo = RunnableLambda(lambda x: x).with_config({"metadata": {"my_key": "my_value"}})
expected_metadata = {
"my_key": "my_value",
"my_other_key": "my_other_value",
}
with collect_runs() as cb:
foo.invoke("hi", {"metadata": {"my_other_key": "my_other_value"}})
run = cb.traced_runs[0]
assert run.extra["metadata"] == expected_metadata
def test_tags_are_appended() -> None:
"""Test tags from with_config are concatenated with those in invocation."""
foo = RunnableLambda(lambda x: x).with_config({"tags": ["my_key"]})
with collect_runs() as cb:
foo.invoke("hi", {"tags": ["invoked_key"]})
run = cb.traced_runs[0]
assert isinstance(run.tags, list)
assert sorted(run.tags) == sorted(["my_key", "invoked_key"])
def test_bind_bind() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])

Loading…
Cancel
Save