Add run_id, run_name to RunnableConfig

This commit is contained in:
Nuno Campos 2023-08-24 12:50:37 +02:00
parent a3c69cf41d
commit f69155b4f7
7 changed files with 162 additions and 66 deletions

View File

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import yaml import yaml
@ -68,6 +69,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_id=config.get("run_id"),
run_name=config.get("run_name"),
**kwargs, **kwargs,
) )
@ -89,6 +92,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_id=config.get("run_id"),
run_name=config.get("run_name"),
**kwargs, **kwargs,
) )
@ -235,6 +240,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
run_name: Optional[str] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Execute the chain. """Execute the chain.
@ -276,6 +283,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
run_id=run_id,
name=run_name,
) )
try: try:
outputs = ( outputs = (
@ -302,6 +311,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
run_name: Optional[str] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Asynchronously execute the chain. """Asynchronously execute the chain.
@ -343,6 +354,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
run_id=run_id,
name=run_name,
) )
try: try:
outputs = ( outputs = (

View File

@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.schema.runnable import RunnableConfig from langchain.schema.runnable import RunnableConfig
from langchain.schema.runnable.config import get_config_list
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
config = self._get_config_list(config, len(inputs)) config = get_config_list(config, len(inputs))
if max_concurrency is None: if max_concurrency is None:
llm_result = self.generate_prompt( llm_result = self.generate_prompt(
@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, self.batch, inputs, config, max_concurrency None, self.batch, inputs, config, max_concurrency
) )
config = self._get_config_list(config, len(inputs)) config = get_config_list(config, len(inputs))
if max_concurrency is None: if max_concurrency is None:
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(

View File

@ -4,6 +4,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
@ -164,6 +165,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Retrieve documents relevant to a query. """Retrieve documents relevant to a query.
@ -193,6 +195,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = callback_manager.on_retriever_start( run_manager = callback_manager.on_retriever_start(
dumpd(self), dumpd(self),
query, query,
run_id=run_id,
**kwargs, **kwargs,
) )
try: try:
@ -220,6 +223,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Asynchronously get documents relevant to a query. """Asynchronously get documents relevant to a query.
@ -249,6 +253,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = await callback_manager.on_retriever_start( run_manager = await callback_manager.on_retriever_start(
dumpd(self), dumpd(self),
query, query,
run_id=run_id,
**kwargs, **kwargs,
) )
try: try:

View File

@ -42,6 +42,7 @@ from langchain.schema.runnable.config import (
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
get_config_list,
get_executor_for_config, get_executor_for_config,
patch_config, patch_config,
) )
@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of batch, which calls invoke N times. Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of abatch, which calls ainvoke N times. Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs) coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
@ -246,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC):
""" --- Helper methods for Subclasses --- """ """ --- Helper methods for Subclasses --- """
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def _call_with_config( def _call_with_config(
self, self,
func: Union[ func: Union[
@ -286,6 +266,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
input, input,
run_type=run_type, run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(func): if accepts_run_manager_and_config(func):
@ -327,6 +309,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
input, input,
run_type=run_type, run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(func): if accepts_run_manager_and_config(func):
@ -384,6 +368,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
run_type=run_type, run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(transformer): if accepts_run_manager_and_config(transformer):
@ -464,6 +450,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
run_type=run_type, run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
) )
try: try:
# mypy can't quite work out thew type guard here, but this is safe, # mypy can't quite work out thew type guard here, but this is safe,
@ -539,7 +527,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
first_error = None first_error = None
for runnable in self.runnables: for runnable in self.runnables:
try: try:
@ -571,7 +561,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
first_error = None first_error = None
for runnable in self.runnables: for runnable in self.runnables:
@ -603,7 +595,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -619,9 +611,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input # start the root runs, one per input
run_managers = [ run_managers = [
cm.on_chain_start( cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self),
input if isinstance(input, dict) else {"input": input},
run_id=config.get("run_id"),
name=config.get("run_name"),
) )
for cm, input in zip(callback_managers, inputs) for cm, input, config in zip(callback_managers, inputs, configs)
] ]
first_error = None first_error = None
@ -661,7 +656,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -677,8 +672,13 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input # start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
cm.on_chain_start(dumpd(self), input) cm.on_chain_start(
for cm, input in zip(callback_managers, inputs) dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
) )
) )
@ -783,7 +783,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# invoke all steps in sequence # invoke all steps in sequence
try: try:
@ -811,7 +813,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# invoke all steps in sequence # invoke all steps in sequence
try: try:
@ -838,7 +842,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -853,8 +857,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
] ]
# start the root runs, one per input # start the root runs, one per input
run_managers = [ run_managers = [
cm.on_chain_start(dumpd(self), input) cm.on_chain_start(
for cm, input in zip(callback_managers, inputs) dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
] ]
# invoke # invoke
@ -889,7 +898,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) )
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -905,8 +914,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# start the root runs, one per input # start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
cm.on_chain_start(dumpd(self), input) cm.on_chain_start(
for cm, input in zip(callback_managers, inputs) dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
) )
) )
@ -942,7 +956,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
streaming_start_index = 0 streaming_start_index = 0
@ -1009,7 +1025,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1 streaming_start_index = len(steps) - 1
@ -1140,7 +1158,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata=None, local_metadata=None,
) )
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# gather results from all steps # gather results from all steps
try: try:
@ -1179,7 +1199,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# gather results from all steps # gather results from all steps
try: try:
@ -1529,7 +1551,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return self.bound.invoke( return self.bound.invoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
) )
async def ainvoke( async def ainvoke(
@ -1539,7 +1563,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return await self.bound.ainvoke( return await self.bound.ainvoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
) )
def batch( def batch(
@ -1548,13 +1574,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
configs = ( configs = cast(
List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config] [{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list) if isinstance(config, list)
else [ else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True) patch_config(
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
for _ in range(len(inputs)) for _ in range(len(inputs))
] ],
) )
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
@ -1564,19 +1594,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
configs = ( configs = cast(
List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config] [{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list) if isinstance(config, list)
else [ else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True) patch_config(
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
for _ in range(len(inputs)) for _ in range(len(inputs))
] ],
)
return await self.bound.abatch(
inputs,
[{**self.config, **(conf or {})} for conf in configs],
**{**self.kwargs, **kwargs},
) )
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
def stream( def stream(
self, self,
@ -1585,7 +1615,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.stream( yield from self.bound.stream(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
) )
async def astream( async def astream(
@ -1595,7 +1627,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.astream( async for item in self.bound.astream(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
): ):
yield item yield item
@ -1606,7 +1640,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any, **kwargs: Any,
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.transform( yield from self.bound.transform(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
) )
async def atransform( async def atransform(
@ -1616,7 +1652,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.atransform( async for item in self.bound.atransform(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
): ):
yield item yield item

View File

@ -3,7 +3,9 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
from uuid import UUID
from typing_extensions import TypedDict from typing_extensions import TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -32,6 +34,16 @@ class RunnableConfig(TypedDict, total=False):
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
""" """
run_name: str
"""
Name for the tracer run for this call. Defaults to the name of the class.
"""
run_id: UUID
"""
Unique ID for the tracer run for this call. Defaults to uuid4().
"""
_locals: Dict[str, Any] _locals: Dict[str, Any]
""" """
Local variables Local variables
@ -62,6 +74,28 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
return empty return empty
def get_config_list(
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def patch_config( def patch_config(
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
*, *,

View File

@ -23,7 +23,7 @@ from langchain.schema.runnable.base import (
RunnableSequence, RunnableSequence,
coerce_to_runnable, coerce_to_runnable,
) )
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig, get_config_list
from langchain.schema.runnable.utils import gather_with_concurrency from langchain.schema.runnable.utils import gather_with_concurrency
@ -131,7 +131,7 @@ class RouterRunnable(
raise ValueError("One or more keys do not have a corresponding runnable") raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys] runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor: with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list( return list(
executor.map( executor.map(
@ -156,7 +156,7 @@ class RouterRunnable(
raise ValueError("One or more keys do not have a corresponding runnable") raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys] runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
return await gather_with_concurrency( return await gather_with_concurrency(
max_concurrency, max_concurrency,
*( *(

View File

@ -8,6 +8,7 @@ from abc import abstractmethod
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -297,6 +298,7 @@ class ChildTool(BaseTool):
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool.""" """Run the tool."""
@ -320,6 +322,7 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input), tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color, color=start_color,
run_id=run_id,
**kwargs, **kwargs,
) )
try: try:
@ -370,6 +373,7 @@ class ChildTool(BaseTool):
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
@ -392,6 +396,7 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input), tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color, color=start_color,
run_id=run_id,
**kwargs, **kwargs,
) )
try: try: