mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add run_id, run_name to RunnableConfig
This commit is contained in:
parent
a3c69cf41d
commit
f69155b4f7
@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
|
||||
@ -68,6 +69,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_id=config.get("run_id"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -89,6 +92,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_id=config.get("run_id"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -235,6 +240,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
@ -276,6 +283,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
run_id=run_id,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
@ -302,6 +311,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
@ -343,6 +354,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
run_id=run_id,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
|
@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.schema.runnable import RunnableConfig
|
||||
from langchain.schema.runnable.config import get_config_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
max_concurrency: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
config = self._get_config_list(config, len(inputs))
|
||||
config = get_config_list(config, len(inputs))
|
||||
|
||||
if max_concurrency is None:
|
||||
llm_result = self.generate_prompt(
|
||||
@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
None, self.batch, inputs, config, max_concurrency
|
||||
)
|
||||
|
||||
config = self._get_config_list(config, len(inputs))
|
||||
config = get_config_list(config, len(inputs))
|
||||
|
||||
if max_concurrency is None:
|
||||
llm_result = await self.agenerate_prompt(
|
||||
|
@ -4,6 +4,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
@ -164,6 +165,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
@ -193,6 +195,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
query,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
@ -220,6 +223,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
@ -249,6 +253,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
query,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
|
@ -42,6 +42,7 @@ from langchain.schema.runnable.config import (
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
patch_config,
|
||||
)
|
||||
@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Default implementation of batch, which calls invoke N times.
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Default implementation of abatch, which calls ainvoke N times.
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
|
||||
|
||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||
@ -246,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
""" --- Helper methods for Subclasses --- """
|
||||
|
||||
def _get_config_list(
|
||||
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||
) -> List[RunnableConfig]:
|
||||
"""
|
||||
Helper method to get a list of configs from a single config or a list of
|
||||
configs, useful for subclasses overriding batch() or abatch().
|
||||
"""
|
||||
if length < 1:
|
||||
raise ValueError(f"length must be >= 1, but got {length}")
|
||||
if isinstance(config, list) and len(config) != length:
|
||||
raise ValueError(
|
||||
f"config must be a list of the same length as inputs, "
|
||||
f"but got {len(config)} configs for {length} inputs"
|
||||
)
|
||||
|
||||
return (
|
||||
list(map(ensure_config, config))
|
||||
if isinstance(config, list)
|
||||
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
|
||||
)
|
||||
|
||||
def _call_with_config(
|
||||
self,
|
||||
func: Union[
|
||||
@ -286,6 +266,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
@ -327,6 +309,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
@ -384,6 +368,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(transformer):
|
||||
@ -464,6 +450,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
# mypy can't quite work out thew type guard here, but this is safe,
|
||||
@ -539,7 +527,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
@ -571,7 +561,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
@ -603,7 +595,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -619,9 +611,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
# start the root runs, one per input
|
||||
run_managers = [
|
||||
cm.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input in zip(callback_managers, inputs)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
first_error = None
|
||||
@ -661,7 +656,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# setup callbacks
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -677,8 +672,13 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(dumpd(self), input)
|
||||
for cm, input in zip(callback_managers, inputs)
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
)
|
||||
|
||||
@ -783,7 +783,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
@ -811,7 +813,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
@ -838,7 +842,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -853,8 +857,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers = [
|
||||
cm.on_chain_start(dumpd(self), input)
|
||||
for cm, input in zip(callback_managers, inputs)
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
# invoke
|
||||
@ -889,7 +898,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
# setup callbacks
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
@ -905,8 +914,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(dumpd(self), input)
|
||||
for cm, input in zip(callback_managers, inputs)
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_id=config.get("run_id"),
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
)
|
||||
|
||||
@ -942,7 +956,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
streaming_start_index = 0
|
||||
@ -1009,7 +1025,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
streaming_start_index = len(steps) - 1
|
||||
@ -1140,7 +1158,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
local_metadata=None,
|
||||
)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
try:
|
||||
@ -1179,7 +1199,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
try:
|
||||
@ -1529,7 +1551,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
return self.bound.invoke(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
@ -1539,7 +1563,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
return await self.bound.ainvoke(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
def batch(
|
||||
@ -1548,13 +1574,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
configs = (
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
[{**self.config, **(conf or {})} for conf in config]
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
|
||||
patch_config(
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
deep_copy_locals=True,
|
||||
)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
],
|
||||
)
|
||||
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||
|
||||
@ -1564,19 +1594,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
configs = (
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
[{**self.config, **(conf or {})} for conf in config]
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
|
||||
patch_config(
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
deep_copy_locals=True,
|
||||
)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
)
|
||||
return await self.bound.abatch(
|
||||
inputs,
|
||||
[{**self.config, **(conf or {})} for conf in configs],
|
||||
**{**self.kwargs, **kwargs},
|
||||
],
|
||||
)
|
||||
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@ -1585,7 +1615,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.stream(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
async def astream(
|
||||
@ -1595,7 +1627,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.astream(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
|
||||
@ -1606,7 +1640,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.transform(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
async def atransform(
|
||||
@ -1616,7 +1652,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.atransform(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
|
||||
|
@ -3,7 +3,9 @@ from __future__ import annotations
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -32,6 +34,16 @@ class RunnableConfig(TypedDict, total=False):
|
||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||
"""
|
||||
|
||||
run_name: str
|
||||
"""
|
||||
Name for the tracer run for this call. Defaults to the name of the class.
|
||||
"""
|
||||
|
||||
run_id: UUID
|
||||
"""
|
||||
Unique ID for the tracer run for this call. Defaults to uuid4().
|
||||
"""
|
||||
|
||||
_locals: Dict[str, Any]
|
||||
"""
|
||||
Local variables
|
||||
@ -62,6 +74,28 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
return empty
|
||||
|
||||
|
||||
def get_config_list(
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||
) -> List[RunnableConfig]:
|
||||
"""
|
||||
Helper method to get a list of configs from a single config or a list of
|
||||
configs, useful for subclasses overriding batch() or abatch().
|
||||
"""
|
||||
if length < 1:
|
||||
raise ValueError(f"length must be >= 1, but got {length}")
|
||||
if isinstance(config, list) and len(config) != length:
|
||||
raise ValueError(
|
||||
f"config must be a list of the same length as inputs, "
|
||||
f"but got {len(config)} configs for {length} inputs"
|
||||
)
|
||||
|
||||
return (
|
||||
list(map(ensure_config, config))
|
||||
if isinstance(config, list)
|
||||
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
|
||||
)
|
||||
|
||||
|
||||
def patch_config(
|
||||
config: Optional[RunnableConfig],
|
||||
*,
|
||||
|
@ -23,7 +23,7 @@ from langchain.schema.runnable.base import (
|
||||
RunnableSequence,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.config import RunnableConfig, get_config_list
|
||||
from langchain.schema.runnable.utils import gather_with_concurrency
|
||||
|
||||
|
||||
@ -131,7 +131,7 @@ class RouterRunnable(
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
return list(
|
||||
executor.map(
|
||||
@ -156,7 +156,7 @@ class RouterRunnable(
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
return await gather_with_concurrency(
|
||||
max_concurrency,
|
||||
*(
|
||||
|
@ -8,6 +8,7 @@ from abc import abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
@ -297,6 +298,7 @@ class ChildTool(BaseTool):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
@ -320,6 +322,7 @@ class ChildTool(BaseTool):
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
@ -370,6 +373,7 @@ class ChildTool(BaseTool):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
@ -392,6 +396,7 @@ class ChildTool(BaseTool):
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user