Add run_id, run_name to RunnableConfig

This commit is contained in:
Nuno Campos 2023-08-24 12:50:37 +02:00
parent a3c69cf41d
commit f69155b4f7
7 changed files with 162 additions and 66 deletions

View File

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import yaml
@ -68,6 +69,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_id=config.get("run_id"),
run_name=config.get("run_name"),
**kwargs,
)
@ -89,6 +92,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_id=config.get("run_id"),
run_name=config.get("run_name"),
**kwargs,
)
@ -235,6 +240,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Execute the chain.
@ -276,6 +283,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id=run_id,
name=run_name,
)
try:
outputs = (
@ -302,6 +311,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
@ -343,6 +354,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id=run_id,
name=run_name,
)
try:
outputs = (

View File

@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable import RunnableConfig
from langchain.schema.runnable.config import get_config_list
logger = logging.getLogger(__name__)
@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
max_concurrency: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
config = self._get_config_list(config, len(inputs))
config = get_config_list(config, len(inputs))
if max_concurrency is None:
llm_result = self.generate_prompt(
@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, self.batch, inputs, config, max_concurrency
)
config = self._get_config_list(config, len(inputs))
config = get_config_list(config, len(inputs))
if max_concurrency is None:
llm_result = await self.agenerate_prompt(

View File

@ -4,6 +4,7 @@ import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
@ -164,6 +165,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[Document]:
"""Retrieve documents relevant to a query.
@ -193,6 +195,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
run_id=run_id,
**kwargs,
)
try:
@ -220,6 +223,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
@ -249,6 +253,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
run_id=run_id,
**kwargs,
)
try:

View File

@ -42,6 +42,7 @@ from langchain.schema.runnable.config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
get_executor_for_config,
patch_config,
)
@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
@ -246,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC):
""" --- Helper methods for Subclasses --- """
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def _call_with_config(
self,
func: Union[
@ -286,6 +266,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
if accepts_run_manager_and_config(func):
@ -327,6 +309,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
if accepts_run_manager_and_config(func):
@ -384,6 +368,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
{"input": ""},
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
if accepts_run_manager_and_config(transformer):
@ -464,6 +450,8 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
{"input": ""},
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
# mypy can't quite work out the type guard here, but this is safe,
@ -539,7 +527,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
@ -571,7 +561,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
@ -603,7 +595,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import CallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@ -619,9 +611,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
dumpd(self),
input if isinstance(input, dict) else {"input": input},
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input in zip(callback_managers, inputs)
for cm, input, config in zip(callback_managers, inputs, configs)
]
first_error = None
@ -661,7 +656,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@ -677,8 +672,13 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(dumpd(self), input)
for cm, input in zip(callback_managers, inputs)
cm.on_chain_start(
dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
)
@ -783,7 +783,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# invoke all steps in sequence
try:
@ -811,7 +813,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# invoke all steps in sequence
try:
@ -838,7 +842,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import CallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@ -853,8 +857,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(dumpd(self), input)
for cm, input in zip(callback_managers, inputs)
cm.on_chain_start(
dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
# invoke
@ -889,7 +898,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
# setup callbacks
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@ -905,8 +914,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(dumpd(self), input)
for cm, input in zip(callback_managers, inputs)
cm.on_chain_start(
dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
)
@ -942,7 +956,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = 0
@ -1009,7 +1025,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1
@ -1140,7 +1158,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# gather results from all steps
try:
@ -1179,7 +1199,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
)
# gather results from all steps
try:
@ -1529,7 +1551,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Output:
return self.bound.invoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
)
async def ainvoke(
@ -1539,7 +1563,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Output:
return await self.bound.ainvoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
)
def batch(
@ -1548,13 +1574,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Output]:
configs = (
configs = cast(
List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list)
else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
patch_config(
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
for _ in range(len(inputs))
]
],
)
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
@ -1564,19 +1594,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Output]:
configs = (
configs = cast(
List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list)
else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
patch_config(
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
for _ in range(len(inputs))
]
)
return await self.bound.abatch(
inputs,
[{**self.config, **(conf or {})} for conf in configs],
**{**self.kwargs, **kwargs},
],
)
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
def stream(
self,
@ -1585,7 +1615,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self.bound.stream(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
)
async def astream(
@ -1595,7 +1627,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
):
yield item
@ -1606,7 +1640,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any,
) -> Iterator[Output]:
yield from self.bound.transform(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
)
async def atransform(
@ -1616,7 +1652,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any,
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
**{**self.kwargs, **kwargs},
):
yield item

View File

@ -3,7 +3,9 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
from uuid import UUID
from typing_extensions import TypedDict
if TYPE_CHECKING:
@ -32,6 +34,16 @@ class RunnableConfig(TypedDict, total=False):
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
run_name: str
"""
Name for the tracer run for this call. Defaults to the name of the class.
"""
run_id: UUID
"""
Unique ID for the tracer run for this call. Defaults to uuid4().
"""
_locals: Dict[str, Any]
"""
Local variables
@ -62,6 +74,28 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
return empty
def get_config_list(
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
    """Expand ``config`` into a list holding one config per input.

    Helper for batch() / abatch() implementations that accept either a single
    config shared by every input or an explicit list of configs.

    Args:
        config: A single config (or None), or a list of configs.
        length: Number of inputs; the returned list has this many entries.

    Returns:
        A list of ``length`` configs. When ``config`` is a list, each element
        is passed through ``ensure_config``; otherwise ``patch_config`` is
        called with ``deep_copy_locals=True`` once per input.

    Raises:
        ValueError: If ``length`` is less than 1, or if ``config`` is a list
            whose size differs from ``length``.
    """
    if length < 1:
        raise ValueError(f"length must be >= 1, but got {length}")

    if not isinstance(config, list):
        # One patch_config call per input, rather than reusing a single result.
        return [patch_config(config, deep_copy_locals=True) for _ in range(length)]

    if len(config) != length:
        raise ValueError(
            f"config must be a list of the same length as inputs, "
            f"but got {len(config)} configs for {length} inputs"
        )
    return [ensure_config(entry) for entry in config]
def patch_config(
config: Optional[RunnableConfig],
*,

View File

@ -23,7 +23,7 @@ from langchain.schema.runnable.base import (
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.config import RunnableConfig, get_config_list
from langchain.schema.runnable.utils import gather_with_concurrency
@ -131,7 +131,7 @@ class RouterRunnable(
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(
executor.map(
@ -156,7 +156,7 @@ class RouterRunnable(
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
configs = get_config_list(config, len(inputs))
return await gather_with_concurrency(
max_concurrency,
*(

View File

@ -8,6 +8,7 @@ from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
@ -297,6 +298,7 @@ class ChildTool(BaseTool):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -320,6 +322,7 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
run_id=run_id,
**kwargs,
)
try:
@ -370,6 +373,7 @@ class ChildTool(BaseTool):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -392,6 +396,7 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
run_id=run_id,
**kwargs,
)
try: