runnable.bind().bind() should combine kwargs, instead of nesting wrappers (#8467)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Nuno Campos 2023-07-29 23:48:30 +01:00 committed by GitHub
parent ae4638aa35
commit b65a9414bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 1 deletions

View File

@ -714,6 +714,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
return self.bound.invoke(input, config, **self.kwargs) return self.bound.invoke(input, config, **self.kwargs)

View File

@ -11,7 +11,7 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM from langchain.llms.fake import FakeListLLM
from langchain.load.dump import dumps from langchain.load.dump import dumpd, dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import ( from langchain.prompts.chat import (
ChatPromptTemplate, ChatPromptTemplate,
@ -609,3 +609,13 @@ def test_seq_prompt_map(
] ]
) )
assert tracer.runs == snapshot assert tracer.runs == snapshot
def test_bind_bind() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])
assert dumpd(
llm.bind(stop=["Thought:"], one="two").bind(
stop=["Observation:"], hello="world"
)
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))