From b65a9414bb9b549f54f7ce917d56879fea9c2fdb Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 29 Jul 2023 23:48:30 +0100 Subject: [PATCH] runnable.bind().bind() should combine kwargs, instead of nesting wrappers (#8467) --------- Co-authored-by: Harrison Chase --- libs/langchain/langchain/schema/runnable.py | 3 +++ .../tests/unit_tests/schema/test_runnable.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 0864fe55d0..09f76de492 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -714,6 +714,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def lc_serializable(self) -> bool: return True + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: + return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: return self.bound.invoke(input, config, **self.kwargs) diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py index 23fda63f45..02104e8b4b 100644 --- a/libs/langchain/tests/unit_tests/schema/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -11,7 +11,7 @@ from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run from langchain.chat_models.fake import FakeListChatModel from langchain.llms.fake import FakeListLLM -from langchain.load.dump import dumps +from langchain.load.dump import dumpd, dumps from langchain.output_parsers.list import CommaSeparatedListOutputParser from langchain.prompts.chat import ( ChatPromptTemplate, @@ -609,3 +609,13 @@ def test_seq_prompt_map( ] ) assert tracer.runs == snapshot + + +def test_bind_bind() -> None: + llm = FakeListLLM(responses=["i'm a textbot"]) + + assert dumpd( + llm.bind(stop=["Thought:"], one="two").bind( + stop=["Observation:"], hello="world" + ) + ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))