mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
wip
This commit is contained in:
parent
539672a7fd
commit
50b13ab938
@ -62,7 +62,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return self(input, **(config or {}), **kwargs)
|
||||
_config: Dict[str, Any] = dict(config) if config else {}
|
||||
_config.pop("_locals", None)
|
||||
return self(input, **_config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@ -76,7 +78,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
None, partial(self.invoke, input, config, **kwargs)
|
||||
)
|
||||
|
||||
return await self.acall(input, **(config or {}), **kwargs)
|
||||
_config: Dict[str, Any] = dict(config) if config else {}
|
||||
_config.pop("_locals", None)
|
||||
return await self.acall(input, **_config, **kwargs)
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from itertools import tee
|
||||
from typing import (
|
||||
Any,
|
||||
@ -66,6 +67,35 @@ class RunnableConfig(TypedDict, total=False):
|
||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||
"""
|
||||
|
||||
_locals: Dict[str, Any]
|
||||
"""
|
||||
Local variables
|
||||
"""
|
||||
|
||||
|
||||
def _empty_config() -> RunnableConfig:
|
||||
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||
|
||||
|
||||
def _get_callback_manager(config: Mapping) -> Any:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
return CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def _get_async_callback_manager(config: Mapping) -> Any:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
return AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
@ -243,7 +273,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return (
|
||||
config
|
||||
if isinstance(config, list)
|
||||
else [config.copy() if config is not None else {} for _ in range(length)]
|
||||
else [deepcopy(config) if config is not None else {} for _ in range(length)]
|
||||
)
|
||||
|
||||
def _call_with_config(
|
||||
@ -255,14 +285,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
@ -288,14 +312,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
@ -322,8 +340,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""Helper method to transform an Iterator of Input values into an Iterator of
|
||||
Output values, with callbacks.
|
||||
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# tee the input so we can iterate over it twice
|
||||
input_for_tracing, input_for_transform = tee(input, 2)
|
||||
# Start the input iterator to ensure the input runnable starts before this one
|
||||
@ -333,11 +349,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output_supported = True
|
||||
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
@ -393,8 +405,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""Helper method to transform an Async Iterator of Input values into an Async
|
||||
Iterator of Output values, with callbacks.
|
||||
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# tee the input so we can iterate over it twice
|
||||
input_for_tracing, input_for_transform = atee(input, 2)
|
||||
# Start the input iterator to ensure the input runnable starts before this one
|
||||
@ -404,11 +414,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output_supported = True
|
||||
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
@ -473,19 +479,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
yield from self.fallbacks
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -516,19 +512,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -751,19 +737,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -771,11 +747,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
callbacks = run_manager.get_child()
|
||||
for step in self.steps:
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
_patch_config(config, callbacks),
|
||||
)
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
@ -790,19 +767,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -946,19 +913,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
def stream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -1023,19 +980,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
async def astream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -1173,19 +1120,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), {"input": input}
|
||||
@ -1464,10 +1401,11 @@ class RouterRunnable(
|
||||
|
||||
|
||||
def _patch_config(
|
||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||
config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
|
||||
) -> RunnableConfig:
|
||||
config = config.copy()
|
||||
config = deepcopy(config)
|
||||
config["callbacks"] = callback_manager
|
||||
config["_locals"] = _locals or {}
|
||||
return config
|
||||
|
||||
|
||||
|
@ -636,7 +636,9 @@ async def _arun_chain(
|
||||
else:
|
||||
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
|
||||
else:
|
||||
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
|
||||
runnable_config = RunnableConfig(
|
||||
tags=tags or [], callbacks=callbacks, _locals={}
|
||||
)
|
||||
output = await chain.ainvoke(inputs_, config=runnable_config)
|
||||
return output
|
||||
|
||||
@ -957,7 +959,9 @@ def _run_chain(
|
||||
else:
|
||||
output = chain(inputs_, callbacks=callbacks, tags=tags)
|
||||
else:
|
||||
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
|
||||
runnable_config = RunnableConfig(
|
||||
tags=tags or [], callbacks=callbacks, _locals={}
|
||||
)
|
||||
output = chain.invoke(inputs_, config=runnable_config)
|
||||
return output
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user