[Enhancement] Add support for directly providing a run_id (#18990)

The root run id (~trace id's) is useful for assigning feedback, but the
current recommended approach is to use callbacks to retrieve it, which
has some drawbacks:
1. Doesn't work for streaming until after the first event
2. Doesn't let you call other endpoints with the same trace ID in
parallel (since you have to wait until the call is completed/started to
use

This PR lets you provide = "run_id" in the runnable config.

Couple considerations:

1. For batch calls, we split the trace up into separate trees (to permit
better rendering). We keep the provided run ID for the first one and
generate a unique one for other elements of the batch.
2. For nested calls, the provided ID is ONLY used on the top root/trace.



### Example Usage


```
chain.invoke("foo", {"run_id": uuid.uuid4()})
```
pull/19216/head^2
William FH 3 months ago committed by GitHub
parent bd329e9aad
commit 780337488e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1183,6 +1183,7 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1197,8 +1198,9 @@ class CallbackManager(BaseCallbackManager):
prompt as an LLM run.
"""
managers = []
for prompt in prompts:
run_id_ = uuid.uuid4()
for i, prompt in enumerate(prompts):
# Can't have duplicate runs with the same run ID (if provided)
run_id_ = run_id if i == 0 and run_id is not None else uuid.uuid4()
handle_event(
self.handlers,
"on_llm_start",
@ -1231,6 +1233,7 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1247,7 +1250,11 @@ class CallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid.uuid4()
if run_id is not None:
run_id_ = run_id
run_id = None
else:
run_id_ = uuid.uuid4()
handle_event(
self.handlers,
"on_chat_model_start",
@ -1520,6 +1527,7 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1539,7 +1547,11 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = []
for prompt in prompts:
run_id_ = uuid.uuid4()
if run_id is not None:
run_id_ = run_id
run_id = None
else:
run_id_ = uuid.uuid4()
tasks.append(
ahandle_event(
@ -1577,6 +1589,7 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1595,7 +1608,11 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid.uuid4()
if run_id is not None:
run_id_ = run_id
run_id = None
else:
run_id_ = uuid.uuid4()
tasks.append(
ahandle_event(

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import (
@ -234,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
@ -312,6 +314,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
@ -371,6 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to the model and return model generations.
@ -415,6 +419,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=run_name,
run_id=run_id,
batch_size=len(messages),
)
results = []
@ -456,6 +461,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@ -502,6 +508,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
options=options,
name=run_name,
batch_size=len(messages),
run_id=run_id,
)
results = await asyncio.gather(

@ -7,6 +7,7 @@ import functools
import inspect
import json
import logging
import uuid
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
@ -271,6 +272,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
.generations[0][0]
@ -293,6 +295,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
return llm_result.generations[0][0].text
@ -423,6 +426,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[GenerationChunk] = None
@ -499,6 +503,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[GenerationChunk] = None
@ -632,6 +637,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to a model and return generations.
@ -717,7 +723,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
] * len(prompts)
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
@ -744,9 +750,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
options=options,
name=run_name,
batch_size=len(prompts),
run_id=run_id_,
)[0]
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
for callback_manager, prompt, run_name, run_id_ in zip(
callback_managers, prompts, run_name_list, run_ids_list
)
]
output = self._generate_helper(
@ -782,6 +789,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
@staticmethod
def _get_run_ids_list(
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
) -> list:
if run_id is None:
return [None] * len(prompts)
if isinstance(run_id, list):
if len(run_id) != len(prompts):
raise ValueError(
"Number of manually provided run_id's does not match batch length."
f" {len(run_id)} != {len(prompts)}"
)
return run_id
return [run_id] + [None] * (len(prompts) - 1)
async def _agenerate_helper(
self,
prompts: List[str],
@ -833,6 +855,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@ -909,7 +932,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
] * len(prompts)
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
@ -937,9 +960,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
options=options,
name=run_name,
batch_size=len(prompts),
run_id=run_id_,
)
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
for callback_manager, prompt, run_name, run_id_ in zip(
callback_managers, prompts, run_name_list, run_ids_list
)
]
)

@ -230,6 +230,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self),
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}
@ -286,6 +287,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self),
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}

@ -1448,6 +1448,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -1495,6 +1496,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -1547,6 +1549,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
@ -1619,6 +1622,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
@ -1694,6 +1698,7 @@ class Runnable(Generic[Input, Output], ABC):
{"input": ""},
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -1781,6 +1786,7 @@ class Runnable(Generic[Input, Output], ABC):
{"input": ""},
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -2262,7 +2268,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
@ -2296,7 +2305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
@ -2354,6 +2366,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
@ -2478,6 +2491,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
@ -2885,7 +2899,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# gather results from all steps
@ -2925,7 +2942,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# gather results from all steps

@ -183,6 +183,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
try:
@ -231,6 +232,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
try:
for idx, branch in enumerate(self.branches):
@ -282,6 +284,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
final_output: Optional[Output] = None
final_output_supported = True
@ -356,6 +359,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
final_output: Optional[Output] = None
final_output_supported = True

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import uuid
import warnings
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
@ -95,6 +97,12 @@ class RunnableConfig(TypedDict, total=False):
configurable.
"""
run_id: Optional[uuid.UUID]
"""
Unique identifier for the tracer run for this call. If not provided, a new UUID
will be generated.
"""
var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
@ -116,6 +124,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
metadata={},
callbacks=None,
recursion_limit=25,
run_id=None,
)
if var_config := var_child_runnable_config.get():
empty.update(
@ -158,11 +167,21 @@ def get_config_list(
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [ensure_config(config) for _ in range(length)]
)
if isinstance(config, list):
return list(map(ensure_config, config))
if length > 1 and isinstance(config, dict) and config.get("run_id") is not None:
warnings.warn(
"Provided run_id be used only for the first element of the batch.",
category=RuntimeWarning,
)
subsequent = cast(
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
)
return [
ensure_config(subsequent) if i else ensure_config(config)
for i in range(length)
]
return [ensure_config(config) for i in range(length)]
def patch_config(
@ -199,6 +218,8 @@ def patch_config(
config["callbacks"] = callbacks
if "run_name" in config:
del config["run_name"]
if "run_id" in config:
del config["run_id"]
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
if max_concurrency is not None:

@ -156,7 +156,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
last_error = None
@ -200,7 +203,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
@ -270,6 +276,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
@ -362,6 +369,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
@ -436,7 +444,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
last_error = None
@ -493,7 +504,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
last_error = None

@ -0,0 +1,15 @@
# from langchain_core.runnables.base import RunnableBinding
# class RunnableLearnable(RunnableBinding):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.parameters = []
# def backward(self):
# for param in self.parameters:
# param.backward()
# def update(self, optimizer):
# for param in self.parameters:
# optimizer.update(param)

@ -20,6 +20,7 @@ tool for the job.
from __future__ import annotations
import inspect
import uuid
import warnings
from abc import abstractmethod
from inspect import signature
@ -243,6 +244,7 @@ class ChildTool(BaseTool):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
@ -259,6 +261,7 @@ class ChildTool(BaseTool):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
@ -339,6 +342,7 @@ class ChildTool(BaseTool):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -362,6 +366,7 @@ class ChildTool(BaseTool):
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
run_id=run_id,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
@ -430,6 +435,7 @@ class ChildTool(BaseTool):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -453,6 +459,7 @@ class ChildTool(BaseTool):
color=start_color,
name=run_name,
inputs=tool_input,
run_id=run_id,
**kwargs,
)
try:

@ -1,4 +1,5 @@
import sys
import uuid
from functools import partial
from operator import itemgetter
from typing import (
@ -136,6 +137,22 @@ class FakeTracer(BaseTracer):
self.runs.append(self._copy_run(run))
def flattened_runs(self) -> List[Run]:
q = [] + self.runs
result = []
while q:
parent = q.pop()
result.append(parent)
if parent.child_runs:
q.extend(parent.child_runs)
return result
@property
def run_ids(self) -> List[Optional[uuid.UUID]]:
runs = self.flattened_runs()
uuids_map = {v: k for k, v in self.uuids_map.items()}
return [uuids_map.get(r.id) for r in runs]
class FakeRunnable(Runnable[str, int]):
def invoke(
@ -1367,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
recursion_limit=25,
configurable={"hello": "there"},
metadata={"hello": "there", "bye": "now"},
run_id=None,
),
)
spy.reset_mock()
@ -1508,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
mocker.call(
@ -1517,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
]
@ -1542,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
@ -1552,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
@ -1620,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
mocker.call(
@ -1629,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
]
@ -4822,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None:
}
tracer = FakeTracer()
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
run_id = uuid.uuid4()
assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
assert list(runnable.stream(None)) == [1, 2, 3]
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
tracer = FakeTracer()
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
run_id = uuid.uuid4()
assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [
1,
2,
3,
]
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
tracer = FakeTracer()
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
run_id = uuid.uuid4()
with pytest.warns(RuntimeWarning):
assert runnable.batch(
[None, None], {"callbacks": [tracer], "run_id": run_id}
) == [6, 6]
assert len(tracer.runs) == 2
assert tracer.runs[0].outputs == {"output": 6}
assert tracer.runs[1].outputs == {"output": 6}
@ -4865,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None:
arunnable = RunnableGenerator(agen)
tracer = FakeTracer()
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
run_id = uuid.uuid4()
assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
tracer = FakeTracer()
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
run_id = uuid.uuid4()
assert [
p
async for p in arunnable.astream(
None, {"callbacks": [tracer], "run_id": run_id}
)
] == [
1,
2,
3,
@ -4887,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None:
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer = FakeTracer()
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
run_id = uuid.uuid4()
with pytest.warns(RuntimeWarning):
assert await arunnable.abatch(
[None, None], {"callbacks": [tracer], "run_id": run_id}
) == [6, 6]
assert len(tracer.runs) == 2
assert tracer.runs[0].outputs == {"output": 6}
assert tracer.runs[1].outputs == {"output": 6}

Loading…
Cancel
Save