[Enhancement] Add support for directly providing a run_id (#18990)

The root run ID (roughly, the trace ID) is useful for assigning feedback, but the
current recommended approach is to retrieve it via callbacks, which
has some drawbacks:
1. It doesn't work for streaming until after the first event.
2. It doesn't let you call other endpoints with the same trace ID in
parallel, since you have to wait until the call has started/completed
before you can use it.

This PR lets you provide a "run_id" in the runnable config.

A couple of considerations:

1. For batch calls, we split the trace into separate trees (to permit
better rendering). We keep the provided run ID for the first element of the
batch and generate a unique one for each subsequent element (a sketch of the
ID-assignment rule follows this list).
2. For nested calls, the provided ID is used ONLY on the top-level root/trace
run; child runs always get freshly generated IDs.
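
For `BaseLLM.generate` (diff below), `run_id` can also be a list with one entry
per prompt. A minimal sketch of the assignment rule it implements;
`assign_batch_run_ids` is a hypothetical stand-in for the private
`_get_run_ids_list` helper added in this PR:

```python
import uuid
from typing import List, Optional, Union

def assign_batch_run_ids(
    run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]],
    batch_size: int,
) -> List[Optional[uuid.UUID]]:
    """Mirror the batch rule: a single run_id applies to the first element
    only; a list must match the batch length exactly."""
    if run_id is None:
        return [None] * batch_size
    if isinstance(run_id, list):
        if len(run_id) != batch_size:
            raise ValueError("Number of run_ids does not match batch length.")
        return run_id
    # Single ID: first element keeps it, the rest are generated downstream.
    return [run_id] + [None] * (batch_size - 1)
```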



### Example Usage

```python
import uuid

# `chain` can be any Runnable, e.g. a prompt | model composition.
chain.invoke("foo", {"run_id": uuid.uuid4()})
```
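
Since the ID is known before the call starts, it can be used concurrently,
e.g. to attach feedback while the run is still in flight. A hedged sketch:
it assumes the LangSmith SDK's `Client.create_feedback` and some existing
`chain`:

```python
import uuid

from langsmith import Client

run_id = uuid.uuid4()
# Kick off the run with a known ID...
chain.invoke("foo", {"run_id": run_id})
# ...and attach feedback to the same trace, without first
# retrieving the ID through callbacks.
Client().create_feedback(run_id, key="user_rating", score=1)
```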
William FH, committed via GitHub (commit 780337488e, parent bd329e9aad)

```diff
@@ -1183,6 +1183,7 @@ class CallbackManager(BaseCallbackManager):
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
+        run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running.
@@ -1197,8 +1198,9 @@ class CallbackManager(BaseCallbackManager):
             prompt as an LLM run.
         """
         managers = []
-        for prompt in prompts:
-            run_id_ = uuid.uuid4()
+        for i, prompt in enumerate(prompts):
+            # Can't have duplicate runs with the same run ID (if provided)
+            run_id_ = run_id if i == 0 and run_id is not None else uuid.uuid4()
             handle_event(
                 self.handlers,
                 "on_llm_start",
@@ -1231,6 +1233,7 @@ class CallbackManager(BaseCallbackManager):
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running.
@@ -1247,6 +1250,10 @@ class CallbackManager(BaseCallbackManager):
         managers = []
         for message_list in messages:
-            run_id_ = uuid.uuid4()
+            if run_id is not None:
+                run_id_ = run_id
+                run_id = None
+            else:
+                run_id_ = uuid.uuid4()
             handle_event(
                 self.handlers,
@@ -1520,6 +1527,7 @@ class AsyncCallbackManager(BaseCallbackManager):
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
+        run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running.
@@ -1539,6 +1547,10 @@ class AsyncCallbackManager(BaseCallbackManager):
         managers = []
         for prompt in prompts:
-            run_id_ = uuid.uuid4()
+            if run_id is not None:
+                run_id_ = run_id
+                run_id = None
+            else:
+                run_id_ = uuid.uuid4()
             tasks.append(
@@ -1577,6 +1589,7 @@ class AsyncCallbackManager(BaseCallbackManager):
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running.
@@ -1595,6 +1608,10 @@ class AsyncCallbackManager(BaseCallbackManager):
         managers = []
         for message_list in messages:
-            run_id_ = uuid.uuid4()
+            if run_id is not None:
+                run_id_ = run_id
+                run_id = None
+            else:
+                run_id_ = uuid.uuid4()
             tasks.append(
```
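
In plain terms: within a single manager call that starts several runs, only
the first run keeps the caller's ID and the rest get fresh UUIDs. A standalone
restatement of that rule (illustrative only, not the library code;
`first_run_keeps_id` is a made-up name):

```python
import uuid
from typing import List, Optional

def first_run_keeps_id(run_id: Optional[uuid.UUID], n: int) -> List[uuid.UUID]:
    # Mirrors the loops above: index 0 reuses the provided ID, others are fresh.
    return [
        run_id if i == 0 and run_id is not None else uuid.uuid4()
        for i in range(n)
    ]

rid = uuid.uuid4()
ids = first_run_keeps_id(rid, 3)
assert ids[0] == rid and len(set(ids)) == 3  # no duplicate run IDs
```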

```diff
@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 import inspect
+import uuid
 import warnings
 from abc import ABC, abstractmethod
 from typing import (
@@ -234,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             batch_size=1,
         )
         generation: Optional[ChatGenerationChunk] = None
@@ -312,6 +314,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             batch_size=1,
         )
         generation: Optional[ChatGenerationChunk] = None
@@ -371,6 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         run_name: Optional[str] = None,
+        run_id: Optional[uuid.UUID] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Pass a sequence of prompts to the model and return model generations.
@@ -415,6 +419,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=run_name,
+            run_id=run_id,
             batch_size=len(messages),
         )
         results = []
@@ -456,6 +461,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         run_name: Optional[str] = None,
+        run_id: Optional[uuid.UUID] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Asynchronously pass a sequence of prompts to a model and return generations.
@@ -502,6 +508,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             options=options,
             name=run_name,
             batch_size=len(messages),
+            run_id=run_id,
         )
         results = await asyncio.gather(
```

```diff
@@ -7,6 +7,7 @@ import functools
 import inspect
 import json
 import logging
+import uuid
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -271,6 +272,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 tags=config.get("tags"),
                 metadata=config.get("metadata"),
                 run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
                 **kwargs,
             )
             .generations[0][0]
@@ -293,6 +295,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             **kwargs,
         )
         return llm_result.generations[0][0].text
@@ -423,6 +426,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             batch_size=1,
         )
         generation: Optional[GenerationChunk] = None
@@ -499,6 +503,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             batch_size=1,
         )
         generation: Optional[GenerationChunk] = None
@@ -632,6 +637,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         tags: Optional[Union[List[str], List[List[str]]]] = None,
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
         run_name: Optional[Union[str, List[str]]] = None,
+        run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Pass a sequence of prompts to a model and return generations.
@@ -717,7 +723,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 )
             ] * len(prompts)
             run_name_list = [cast(Optional[str], run_name)] * len(prompts)
+        run_ids_list = self._get_run_ids_list(run_id, prompts)
         params = self.dict()
         params["stop"] = stop
         options = {"stop": stop}
@@ -744,9 +750,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     options=options,
                     name=run_name,
                     batch_size=len(prompts),
+                    run_id=run_id_,
                 )[0]
-                for callback_manager, prompt, run_name in zip(
-                    callback_managers, prompts, run_name_list
+                for callback_manager, prompt, run_name, run_id_ in zip(
+                    callback_managers, prompts, run_name_list, run_ids_list
                 )
             ]
             output = self._generate_helper(
@@ -782,6 +789,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

+    @staticmethod
+    def _get_run_ids_list(
+        run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
+    ) -> list:
+        if run_id is None:
+            return [None] * len(prompts)
+        if isinstance(run_id, list):
+            if len(run_id) != len(prompts):
+                raise ValueError(
+                    "Number of manually provided run_id's does not match batch length."
+                    f" {len(run_id)} != {len(prompts)}"
+                )
+            return run_id
+        return [run_id] + [None] * (len(prompts) - 1)
+
     async def _agenerate_helper(
         self,
         prompts: List[str],
@@ -833,6 +855,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         tags: Optional[Union[List[str], List[List[str]]]] = None,
         metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
         run_name: Optional[Union[str, List[str]]] = None,
+        run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Asynchronously pass a sequence of prompts to a model and return generations.
@@ -909,7 +932,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 )
             ] * len(prompts)
             run_name_list = [cast(Optional[str], run_name)] * len(prompts)
+        run_ids_list = self._get_run_ids_list(run_id, prompts)
         params = self.dict()
         params["stop"] = stop
         options = {"stop": stop}
@@ -937,9 +960,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     options=options,
                     name=run_name,
                     batch_size=len(prompts),
+                    run_id=run_id_,
                 )
-                for callback_manager, prompt, run_name in zip(
-                    callback_managers, prompts, run_name_list
+                for callback_manager, prompt, run_name, run_id_ in zip(
+                    callback_managers, prompts, run_name_list, run_ids_list
                 )
             ]
         )
```
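
A quick sketch of the per-prompt variant enabled by `_get_run_ids_list`,
assuming `FakeListLLM` from `langchain_core.language_models.fake` as a
stand-in model:

```python
import uuid

from langchain_core.language_models.fake import FakeListLLM

llm = FakeListLLM(responses=["alpha", "beta"])
run_ids = [uuid.uuid4(), uuid.uuid4()]

# One trace ID per prompt; a length mismatch raises ValueError.
result = llm.generate(["prompt one", "prompt two"], run_id=run_ids)
```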

```diff
@@ -230,6 +230,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
             dumpd(self),
             query,
             name=run_name,
+            run_id=kwargs.pop("run_id", None),
         )
         try:
             _kwargs = kwargs if self._expects_other_args else {}
@@ -286,6 +287,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
             dumpd(self),
             query,
             name=run_name,
+            run_id=kwargs.pop("run_id", None),
         )
         try:
             _kwargs = kwargs if self._expects_other_args else {}
```
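
Because the retriever entry points pop `run_id` out of `**kwargs`, the ID can
be threaded through the classic keyword interface as well; a sketch assuming
some existing `retriever` instance:

```python
import uuid

run_id = uuid.uuid4()
# `retriever` is any BaseRetriever; run_id is consumed for the retriever
# run before the remaining kwargs are forwarded on.
docs = retriever.get_relevant_documents("my query", run_id=run_id)
```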

```diff
@@ -1448,6 +1448,7 @@ class Runnable(Generic[Input, Output], ABC):
             input,
             run_type=run_type,
             name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )
         try:
             child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -1495,6 +1496,7 @@ class Runnable(Generic[Input, Output], ABC):
             input,
             run_type=run_type,
             name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )
         try:
             child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -1547,6 +1549,7 @@ class Runnable(Generic[Input, Output], ABC):
                 input,
                 run_type=run_type,
                 name=config.get("run_name") or self.get_name(),
+                run_id=config.pop("run_id", None),
             )
             for callback_manager, input, config in zip(
                 callback_managers, input, configs
@@ -1619,6 +1622,7 @@ class Runnable(Generic[Input, Output], ABC):
                 input,
                 run_type=run_type,
                 name=config.get("run_name") or self.get_name(),
+                run_id=config.pop("run_id", None),
             )
             for callback_manager, input, config in zip(
                 callback_managers, input, configs
@@ -1694,6 +1698,7 @@ class Runnable(Generic[Input, Output], ABC):
             {"input": ""},
             run_type=run_type,
             name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )
         try:
             child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -1781,6 +1786,7 @@ class Runnable(Generic[Input, Output], ABC):
             {"input": ""},
             run_type=run_type,
             name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )
         try:
             child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -2262,7 +2268,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
         callback_manager = get_callback_manager_for_config(config)
         # start the root run
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name") or self.get_name()
+            dumpd(self),
+            input,
+            name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )

         # invoke all steps in sequence
@@ -2296,7 +2305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name") or self.get_name()
+            dumpd(self),
+            input,
+            name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )

         # invoke all steps in sequence
@@ -2354,6 +2366,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 dumpd(self),
                 input,
                 name=config.get("run_name") or self.get_name(),
+                run_id=config.pop("run_id", None),
             )
             for cm, input, config in zip(callback_managers, inputs, configs)
         ]
@@ -2478,6 +2491,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 dumpd(self),
                 input,
                 name=config.get("run_name") or self.get_name(),
+                run_id=config.pop("run_id", None),
             )
             for cm, input, config in zip(callback_managers, inputs, configs)
         )
@@ -2885,7 +2899,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         )
         # start the root run
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name") or self.get_name()
+            dumpd(self),
+            input,
+            name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )

         # gather results from all steps
@@ -2925,7 +2942,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name") or self.get_name()
+            dumpd(self),
+            input,
+            name=config.get("run_name") or self.get_name(),
+            run_id=config.pop("run_id", None),
         )

         # gather results from all steps
```

```diff
@@ -183,6 +183,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
             dumpd(self),
             input,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )

         try:
@@ -231,6 +232,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
             dumpd(self),
             input,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )
         try:
             for idx, branch in enumerate(self.branches):
@@ -282,6 +284,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
             dumpd(self),
             input,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )
         final_output: Optional[Output] = None
         final_output_supported = True
@@ -356,6 +359,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
             dumpd(self),
             input,
             name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )
         final_output: Optional[Output] = None
         final_output_supported = True
```

```diff
@@ -1,6 +1,8 @@
 from __future__ import annotations

 import asyncio
+import uuid
+import warnings
 from concurrent.futures import Executor, Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from contextvars import ContextVar, copy_context
@@ -95,6 +97,12 @@ class RunnableConfig(TypedDict, total=False):
     configurable.
     """

+    run_id: Optional[uuid.UUID]
+    """
+    Unique identifier for the tracer run for this call. If not provided, a new UUID
+    will be generated.
+    """
+

 var_child_runnable_config = ContextVar(
     "child_runnable_config", default=RunnableConfig()
@@ -116,6 +124,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
         metadata={},
         callbacks=None,
         recursion_limit=25,
+        run_id=None,
     )
     if var_config := var_child_runnable_config.get():
         empty.update(
@@ -158,11 +167,21 @@ def get_config_list(
             f"but got {len(config)} configs for {length} inputs"
         )

-    return (
-        list(map(ensure_config, config))
-        if isinstance(config, list)
-        else [ensure_config(config) for _ in range(length)]
-    )
+    if isinstance(config, list):
+        return list(map(ensure_config, config))
+    if length > 1 and isinstance(config, dict) and config.get("run_id") is not None:
+        warnings.warn(
+            "Provided run_id will be used only for the first element of the batch.",
+            category=RuntimeWarning,
+        )
+        subsequent = cast(
+            RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
+        )
+        return [
+            ensure_config(subsequent) if i else ensure_config(config)
+            for i in range(length)
+        ]
+    return [ensure_config(config) for i in range(length)]


 def patch_config(
@@ -199,6 +218,8 @@ def patch_config(
         config["callbacks"] = callbacks
         if "run_name" in config:
             del config["run_name"]
+        if "run_id" in config:
+            del config["run_id"]
     if recursion_limit is not None:
         config["recursion_limit"] = recursion_limit
     if max_concurrency is not None:
```
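
Taken together, `get_config_list` strips `run_id` from every batch element
after the first (emitting a `RuntimeWarning`), and `patch_config` deletes it
before configs propagate to children, which is what confines the provided ID
to the root run. A quick demonstration of the warning, assuming a trivial
runnable built from `RunnableLambda`:

```python
import uuid
import warnings

from langchain_core.runnables import RunnableLambda

double = RunnableLambda(lambda x: x * 2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Only the first element of the batch gets this run_id;
    # the second gets a freshly generated one.
    double.batch([1, 2], {"run_id": uuid.uuid4()})

assert any(issubclass(w.category, RuntimeWarning) for w in caught)
```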

```diff
@@ -156,7 +156,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         callback_manager = get_callback_manager_for_config(config)
         # start the root run
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name")
+            dumpd(self),
+            input,
+            name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )
         first_error = None
         last_error = None
@@ -200,7 +203,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name")
+            dumpd(self),
+            input,
+            name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )
         first_error = None
@@ -270,6 +276,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
                 dumpd(self),
                 input if isinstance(input, dict) else {"input": input},
                 name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
             )
             for cm, input, config in zip(callback_managers, inputs, configs)
         ]
@@ -362,6 +369,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
                 dumpd(self),
                 input,
                 name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
             )
             for cm, input, config in zip(callback_managers, inputs, configs)
         )
@@ -436,7 +444,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         callback_manager = get_callback_manager_for_config(config)
         # start the root run
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name")
+            dumpd(self),
+            input,
+            name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
        )
         first_error = None
         last_error = None
@@ -493,7 +504,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name")
+            dumpd(self),
+            input,
+            name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
         )
         first_error = None
         last_error = None
```

```diff
@@ -0,0 +1,15 @@
+# from langchain_core.runnables.base import RunnableBinding


+# class RunnableLearnable(RunnableBinding):
+#     def __init__(self, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.parameters = []

+#     def backward(self):
+#         for param in self.parameters:
+#             param.backward()

+#     def update(self, optimizer):
+#         for param in self.parameters:
+#             optimizer.update(param)
```

```diff
@@ -20,6 +20,7 @@ tool for the job.
 from __future__ import annotations

 import inspect
+import uuid
 import warnings
 from abc import abstractmethod
 from inspect import signature
@@ -243,6 +244,7 @@ class ChildTool(BaseTool):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             **kwargs,
         )
@@ -259,6 +261,7 @@ class ChildTool(BaseTool):
             tags=config.get("tags"),
             metadata=config.get("metadata"),
             run_name=config.get("run_name"),
+            run_id=config.pop("run_id", None),
             **kwargs,
         )
@@ -339,6 +342,7 @@ class ChildTool(BaseTool):
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         run_name: Optional[str] = None,
+        run_id: Optional[uuid.UUID] = None,
         **kwargs: Any,
     ) -> Any:
         """Run the tool."""
@@ -362,6 +366,7 @@ class ChildTool(BaseTool):
             tool_input if isinstance(tool_input, str) else str(tool_input),
             color=start_color,
             name=run_name,
+            run_id=run_id,
             # Inputs by definition should always be dicts.
             # For now, it's unclear whether this assumption is ever violated,
             # but if it is we will send a `None` value to the callback instead
@@ -430,6 +435,7 @@ class ChildTool(BaseTool):
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         run_name: Optional[str] = None,
+        run_id: Optional[uuid.UUID] = None,
         **kwargs: Any,
     ) -> Any:
         """Run the tool asynchronously."""
@@ -453,6 +459,7 @@ class ChildTool(BaseTool):
             color=start_color,
             name=run_name,
             inputs=tool_input,
+            run_id=run_id,
             **kwargs,
         )
         try:
```
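
Tools get the same treatment through both the runnable config and the direct
`run` keyword added above. A minimal sketch using the `@tool` decorator from
`langchain_core.tools`:

```python
import uuid

from langchain_core.tools import tool

@tool
def shout(text: str) -> str:
    """Upper-case the input."""
    return text.upper()

rid = uuid.uuid4()
# Via the runnable interface...
shout.invoke("hello", {"run_id": rid})
# ...or via the classic entry point, whose signature now accepts run_id.
shout.run("hello", run_id=uuid.uuid4())
```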

```diff
@@ -1,4 +1,5 @@
 import sys
+import uuid
 from functools import partial
 from operator import itemgetter
 from typing import (
@@ -136,6 +137,22 @@ class FakeTracer(BaseTracer):
         self.runs.append(self._copy_run(run))

+    def flattened_runs(self) -> List[Run]:
+        q = [] + self.runs
+        result = []
+        while q:
+            parent = q.pop()
+            result.append(parent)
+            if parent.child_runs:
+                q.extend(parent.child_runs)
+        return result
+
+    @property
+    def run_ids(self) -> List[Optional[uuid.UUID]]:
+        runs = self.flattened_runs()
+        uuids_map = {v: k for k, v in self.uuids_map.items()}
+        return [uuids_map.get(r.id) for r in runs]
+

 class FakeRunnable(Runnable[str, int]):
     def invoke(
@@ -1367,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
             recursion_limit=25,
             configurable={"hello": "there"},
             metadata={"hello": "there", "bye": "now"},
+            run_id=None,
         ),
     )
     spy.reset_mock()
@@ -1508,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
                 tags=["c"],
                 callbacks=None,
                 recursion_limit=5,
+                run_id=None,
             ),
         ),
         mocker.call(
@@ -1517,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
                 tags=["c"],
                 callbacks=None,
                 recursion_limit=5,
+                run_id=None,
             ),
         ),
     ]
@@ -1542,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
             tags=["c"],
             callbacks=None,
             recursion_limit=5,
+            run_id=None,
         ),
     )
     second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
@@ -1552,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
             tags=["c"],
             callbacks=None,
             recursion_limit=5,
+            run_id=None,
         ),
     )
@@ -1620,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
             tags=[],
             callbacks=None,
             recursion_limit=25,
+            run_id=None,
         ),
     ),
     mocker.call(
@@ -1629,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
             tags=[],
             callbacks=None,
             recursion_limit=25,
+            run_id=None,
         ),
     ),
 ]
@@ -4822,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None:
     }

     tracer = FakeTracer()
-    assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
+    run_id = uuid.uuid4()
+    assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
     assert len(tracer.runs) == 1
     assert tracer.runs[0].outputs == {"output": 6}
     assert len(tracer.runs[0].child_runs) == 3
     assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    run_ids = tracer.run_ids
+    assert run_id in run_ids
+    assert len(run_ids) == len(set(run_ids))

     tracer.runs.clear()
     assert list(runnable.stream(None)) == [1, 2, 3]
     assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

     tracer = FakeTracer()
-    assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
+    run_id = uuid.uuid4()
+    assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [
+        1,
+        2,
+        3,
+    ]
     assert len(tracer.runs) == 1
     assert tracer.runs[0].outputs == {"output": 6}
     assert len(tracer.runs[0].child_runs) == 3
     assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    run_ids = tracer.run_ids
+    assert run_id in run_ids
+    assert len(run_ids) == len(set(run_ids))
+
+    tracer.runs.clear()

     tracer = FakeTracer()
-    assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
+    run_id = uuid.uuid4()
+    with pytest.warns(RuntimeWarning):
+        assert runnable.batch(
+            [None, None], {"callbacks": [tracer], "run_id": run_id}
+        ) == [6, 6]
     assert len(tracer.runs) == 2
     assert tracer.runs[0].outputs == {"output": 6}
     assert tracer.runs[1].outputs == {"output": 6}
@@ -4865,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None:
     arunnable = RunnableGenerator(agen)

     tracer = FakeTracer()
-    assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
+    run_id = uuid.uuid4()
+    assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
     assert len(tracer.runs) == 1
     assert tracer.runs[0].outputs == {"output": 6}
     assert len(tracer.runs[0].child_runs) == 3
     assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    run_ids = tracer.run_ids
+    assert run_id in run_ids
+    assert len(run_ids) == len(set(run_ids))

     tracer.runs.clear()
     assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
     assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

     tracer = FakeTracer()
-    assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
+    run_id = uuid.uuid4()
+    assert [
+        p
+        async for p in arunnable.astream(
+            None, {"callbacks": [tracer], "run_id": run_id}
+        )
+    ] == [
         1,
         2,
         3,
@@ -4887,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None:
     assert len(tracer.runs[0].child_runs) == 3
     assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+    run_ids = tracer.run_ids
+    assert run_id in run_ids
+    assert len(run_ids) == len(set(run_ids))

     tracer = FakeTracer()
-    assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
+    run_id = uuid.uuid4()
+    with pytest.warns(RuntimeWarning):
+        assert await arunnable.abatch(
+            [None, None], {"callbacks": [tracer], "run_id": run_id}
+        ) == [6, 6]
     assert len(tracer.runs) == 2
     assert tracer.runs[0].outputs == {"output": 6}
     assert tracer.runs[1].outputs == {"output": 6}
```
