Add Runnable.bind method to attach kwargs to a Runnable that will be passed to all invoke/stream/batch calls when it is run (#8368)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
pull/8371/head
Nuno Campos 1 year ago committed by GitHub
parent cf608f876b
commit 0eca3e7d90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -131,6 +131,12 @@ class Runnable(Generic[Input, Output], ABC):
) -> AsyncIterator[Output]:
yield await self.ainvoke(input, config)
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""
Bind arguments to a Runnable, returning a new Runnable.
"""
return RunnableBinding(bound=self, kwargs=kwargs)
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
@ -692,6 +698,60 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
return self._call_with_config(lambda x: x, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]):
bound: Runnable[Input, Output]
kwargs: Mapping[str, Any]
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
return self.bound.invoke(input, config, **self.kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
return await self.bound.ainvoke(input, config, **self.kwargs)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
return self.bound.batch(
inputs, config, max_concurrency=max_concurrency, **self.kwargs
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
return await self.bound.abatch(
inputs, config, max_concurrency=max_concurrency, **self.kwargs
)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
yield from self.bound.stream(input, config, **self.kwargs)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
async for item in self.bound.astream(input, config, **self.kwargs):
yield item
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:

File diff suppressed because one or more lines are too long

@ -566,7 +566,7 @@ def test_seq_prompt_map(
prompt
| passthrough
| {
"chat": chat,
"chat": chat.bind(stop=["Thought:"]),
"llm": llm,
"passthrough": passthrough,
}

Loading…
Cancel
Save