This commit is contained in:
Bagatur 2023-08-09 13:26:09 -07:00
parent 539672a7fd
commit 50b13ab938
3 changed files with 60 additions and 114 deletions

View File

@ -62,7 +62,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return self(input, **(config or {}), **kwargs)
_config: Dict[str, Any] = dict(config) if config else {}
_config.pop("_locals", None)
return self(input, **_config, **kwargs)
async def ainvoke(
self,
@ -76,7 +78,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
None, partial(self.invoke, input, config, **kwargs)
)
return await self.acall(input, **(config or {}), **kwargs)
_config: Dict[str, Any] = dict(config) if config else {}
_config.pop("_locals", None)
return await self.acall(input, **_config, **kwargs)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from itertools import tee
from typing import (
Any,
@ -66,6 +67,35 @@ class RunnableConfig(TypedDict, total=False):
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
_locals: Dict[str, Any]
"""
Local variables
"""
def _empty_config() -> RunnableConfig:
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
def _get_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def _get_async_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
@ -243,7 +273,7 @@ class Runnable(Generic[Input, Output], ABC):
return (
config
if isinstance(config, list)
else [config.copy() if config is not None else {} for _ in range(length)]
else [deepcopy(config) if config is not None else {} for _ in range(length)]
)
def _call_with_config(
@ -255,14 +285,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
from langchain.callbacks.manager import CallbackManager
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
@ -288,14 +312,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
@ -322,8 +340,6 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
from langchain.callbacks.manager import CallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
@ -333,11 +349,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output_supported = True
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
@ -393,8 +405,6 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
@ -404,11 +414,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output_supported = True
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
@ -473,19 +479,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
yield from self.fallbacks
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -516,19 +512,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -751,19 +737,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
config = config or _empty_config()
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -771,11 +747,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke all steps in sequence
try:
callbacks = run_manager.get_child()
for step in self.steps:
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
_patch_config(config, callbacks),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@ -790,19 +767,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -946,19 +913,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -1023,19 +980,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -1173,19 +1120,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input}
@ -1464,10 +1401,11 @@ class RouterRunnable(
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
) -> RunnableConfig:
config = config.copy()
config = deepcopy(config)
config["callbacks"] = callback_manager
config["_locals"] = _locals or {}
return config

View File

@ -636,7 +636,9 @@ async def _arun_chain(
else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
runnable_config = RunnableConfig(
tags=tags or [], callbacks=callbacks, _locals={}
)
output = await chain.ainvoke(inputs_, config=runnable_config)
return output
@ -957,7 +959,9 @@ def _run_chain(
else:
output = chain(inputs_, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
runnable_config = RunnableConfig(
tags=tags or [], callbacks=callbacks, _locals={}
)
output = chain.invoke(inputs_, config=runnable_config)
return output