core[minor]: Add an async root listener and with_alisteners method (#22151)

- [x] **Adding AsyncRootListener**: "langchain_core: Adding
AsyncRootListener"

- **Description:** Adds an AsyncBaseTracer, an AsyncRootListener, and a
`with_alisteners` method. This enables binding async root listeners to
runnables; previously this was only supported for sync listeners.
- **Issue:** None
- **Dependencies:** None

- [x] **Add tests and docs**: Added unit tests and an example code snippet
within the function description of `with_alisteners`


- [x] **Lint and test**: Ran `make format_diff`, `make lint_diff` and `make
test`
pull/22635/head
Nicolas Nkiere 4 months ago committed by GitHub
parent 2904c50cd5
commit 51005e2776
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -95,6 +95,7 @@ if TYPE_CHECKING:
RunLog,
RunLogPatch,
)
from langchain_core.tracers.root_listeners import AsyncListener
from langchain_core.tracers.schemas import Run
@ -1327,6 +1328,86 @@ class Runnable(Generic[Input, Output], ABC):
],
)
def with_alisteners(
    self,
    *,
    on_start: Optional[AsyncListener] = None,
    on_end: Optional[AsyncListener] = None,
    on_error: Optional[AsyncListener] = None,
) -> Runnable[Input, Output]:
    """Bind async lifecycle listeners to a Runnable, returning a new Runnable.

    Args:
        on_start: Awaited before the runnable starts running, with the Run
            object.
        on_end: Awaited after the runnable finishes running, with the Run
            object.
        on_error: Awaited if the runnable throws an error, with the Run
            object.

    The Run object contains information about the run, including its id,
    type, input, output, error, start_time, end_time, and any tags or
    metadata added to the run.

    Example:

        .. code-block:: python

            import asyncio
            import time
            from datetime import datetime, timezone

            from langchain_core.runnables import RunnableLambda

            def format_t(timestamp: float) -> str:
                return datetime.fromtimestamp(
                    timestamp, tz=timezone.utc
                ).isoformat()

            async def test_runnable(time_to_sleep: int):
                print(f"Runnable[{time_to_sleep}s]: starts at {format_t(time.time())}")
                await asyncio.sleep(time_to_sleep)
                print(f"Runnable[{time_to_sleep}s]: ends at {format_t(time.time())}")

            async def fn_start(run_obj: Run):
                print(f"on start callback starts at {format_t(time.time())}")
                await asyncio.sleep(3)
                print(f"on start callback ends at {format_t(time.time())}")

            async def fn_end(run_obj: Run):
                print(f"on end callback starts at {format_t(time.time())}")
                await asyncio.sleep(2)
                print(f"on end callback ends at {format_t(time.time())}")

            runnable = RunnableLambda(test_runnable).with_alisteners(
                on_start=fn_start,
                on_end=fn_end
            )

            async def concurrent_runs():
                await asyncio.gather(runnable.ainvoke(2), runnable.ainvoke(3))

            asyncio.run(concurrent_runs())

        Result:

        .. code-block:: text

            on start callback starts at 2024-05-16T14:20:29.637053+00:00
            on start callback starts at 2024-05-16T14:20:29.637150+00:00
            on start callback ends at 2024-05-16T14:20:32.638305+00:00
            on start callback ends at 2024-05-16T14:20:32.638383+00:00
            Runnable[2s]: starts at 2024-05-16T14:20:32.638849+00:00
            Runnable[3s]: starts at 2024-05-16T14:20:32.638999+00:00
            Runnable[2s]: ends at 2024-05-16T14:20:34.640016+00:00
            on end callback starts at 2024-05-16T14:20:34.640534+00:00
            Runnable[3s]: ends at 2024-05-16T14:20:35.640169+00:00
            on end callback starts at 2024-05-16T14:20:35.640574+00:00
            on end callback ends at 2024-05-16T14:20:36.640654+00:00
            on end callback ends at 2024-05-16T14:20:37.641751+00:00
    """
    # Imported lazily to avoid a circular import with the tracers package.
    from langchain_core.tracers.root_listeners import AsyncRootListenersTracer

    return RunnableBinding(
        bound=self,
        config_factories=[
            lambda config: {
                "callbacks": [
                    AsyncRootListenersTracer(
                        config=config,
                        on_start=on_start,
                        on_end=on_end,
                        on_error=on_error,
                    )
                ],
            }
        ],
    )
def with_types(
self,
*,
@ -4294,6 +4375,33 @@ class RunnableEach(RunnableEachBase[Input, Output]):
)
)
def with_alisteners(
    self,
    *,
    on_start: Optional[AsyncListener] = None,
    on_end: Optional[AsyncListener] = None,
    on_error: Optional[AsyncListener] = None,
) -> RunnableEach[Input, Output]:
    """Bind async lifecycle listeners to a Runnable, returning a new Runnable.

    on_start: Awaited with the Run object before the runnable starts.
    on_end: Awaited with the Run object after the runnable finishes.
    on_error: Awaited with the Run object if the runnable raises.

    The Run object contains information about the run, including its id,
    type, input, output, error, start_time, end_time, and any tags or
    metadata added to the run.
    """
    # Delegate listener binding to the wrapped runnable, then re-wrap.
    listened_bound = self.bound.with_alisteners(
        on_start=on_start,
        on_end=on_end,
        on_error=on_error,
    )
    return RunnableEach(bound=listened_bound)
class RunnableBindingBase(RunnableSerializable[Input, Output]):
"""Runnable that delegates calls to another Runnable with a set of kwargs.

@ -2,38 +2,27 @@
from __future__ import annotations
import asyncio
import logging
import sys
import traceback
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from uuid import UUID
from tenacity import RetryCallState
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.exceptions import TracerException # noqa
from langchain_core.messages import BaseMessage
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.tracers.core import _TracerCore
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
@ -42,90 +31,16 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class BaseTracer(BaseCallbackHandler, ABC):
class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
"""Base interface for tracers."""
def __init__(
self,
*,
_schema_format: Literal[
"original", "streaming_events", "original+chat"
] = "original",
**kwargs: Any,
) -> None:
"""Initialize the tracer.
Args:
_schema_format: Primarily changes how the inputs and outputs are
handled. For internal use only. This API will change.
- 'original' is the format used by all current tracers.
This format is slightly inconsistent with respect to inputs
and outputs.
- 'streaming_events' is used for supporting streaming events,
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
- 'original+chat' is a format that is the same as 'original'
except it does NOT raise an attribute error on_chat_model_start
kwargs: Additional keyword arguments that will be passed to
the super class.
"""
super().__init__(**kwargs)
self._schema_format = _schema_format # For internal use only API will change.
self.run_map: Dict[str, Run] = {}
"""Map of run ID to run. Cleared on run end."""
self.order_map: Dict[UUID, Tuple[UUID, str]] = {}
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
@staticmethod
def _add_child_run(
parent_run: Run,
child_run: Run,
) -> None:
"""Add child run to a chain run or tool run."""
parent_run.child_runs.append(child_run)
@abstractmethod
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
@staticmethod
def _get_stacktrace(error: BaseException) -> str:
"""Get the stacktrace of the parent error."""
msg = repr(error)
try:
if sys.version_info < (3, 10):
tb = traceback.format_exception(
error.__class__, error, error.__traceback__
)
else:
tb = traceback.format_exception(error)
return (msg + "\n\n".join(tb)).strip()
except: # noqa: E722
return msg
def _start_trace(self, run: Run) -> None:
"""Start a trace for a run."""
current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id)
if run.parent_run_id:
if parent := self.order_map.get(run.parent_run_id):
run.trace_id, run.dotted_order = parent
run.dotted_order += "." + current_dotted_order
if parent_run := self.run_map.get(str(run.parent_run_id)):
self._add_child_run(parent_run, run)
else:
logger.debug(
f"Parent run {run.parent_run_id} not found for run {run.id}."
" Treating as a root run."
)
run.parent_run_id = None
run.trace_id = run.id
run.dotted_order = current_dotted_order
else:
run.trace_id = run.id
run.dotted_order = current_dotted_order
self.order_map[run.id] = (run.trace_id, run.dotted_order)
self.run_map[str(run.id)] = run
super()._start_trace(run)
self._on_run_create(run)
def _end_trace(self, run: Run) -> None:
@ -135,25 +50,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
self.run_map.pop(str(run.id))
self._on_run_update(run)
def _get_run(
self, run_id: UUID, run_type: Union[str, Set[str], None] = None
) -> Run:
try:
run = self.run_map[str(run_id)]
except KeyError as exc:
raise TracerException(f"No indexed run ID {run_id}.") from exc
if isinstance(run_type, str):
run_types: Union[Set[str], None] = {run_type}
else:
run_types = run_type
if run_types is not None and run.run_type not in run_types:
raise TracerException(
f"Found {run.run_type} run at ID {run_id}, "
f"but expected {run_types} run."
)
return run
def on_chat_model_start(
self,
serialized: Dict[str, Any],
@ -167,35 +63,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
if self._schema_format not in ("streaming_events", "original+chat"):
# Please keep this un-implemented for backwards compatibility.
# When it's unimplemented old tracers that use the "original" format
# fallback on the on_llm_start method implementation if they
# find that the on_chat_model_start method is not implemented.
# This can eventually be cleaned up by writing a "modern" tracer
# that has all the updated schema changes corresponding to
# the "streaming_events" format.
raise NotImplementedError(
f"Chat model tracing is not supported in "
f"for {self._schema_format} format."
)
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,
chat_model_run = self._create_chat_model_run(
serialized=serialized,
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
# WARNING: This is valid ONLY for streaming_events.
# run_type="llm" is what's used by virtually all tracers.
# Changing this to "chat_model" may break triggering on_llm_start
run_type="chat_model",
messages=messages,
run_id=run_id,
parent_run_id=parent_run_id,
tags=tags,
name=name, # type: ignore[arg-type]
metadata=metadata,
name=name,
**kwargs,
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)
@ -214,21 +90,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
llm_run = Run(
id=run_id,
parent_run_id=parent_run_id,
llm_run = self._create_llm_run(
serialized=serialized,
# TODO: Figure out how to expose kwargs here
inputs={"prompts": prompts},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
run_type="llm",
tags=tags or [],
name=name, # type: ignore[arg-type]
prompts=prompts,
run_id=run_id,
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
name=name,
**kwargs,
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
@ -246,16 +116,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Run on new LLM token. Only available when streaming is enabled."""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
llm_run.events.append(
{
"name": "new_token",
"time": datetime.now(timezone.utc),
"kwargs": event_kwargs,
},
llm_run = self._llm_run_with_token_event(
token=token,
run_id=run_id,
chunk=chunk,
parent_run_id=parent_run_id,
**kwargs,
)
self._on_llm_new_token(llm_run, token, chunk)
return llm_run
@ -267,27 +133,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
llm_run = self._get_run(run_id)
retry_d: Dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
}
if retry_state.outcome is None:
retry_d["outcome"] = "N/A"
elif retry_state.outcome.failed:
retry_d["outcome"] = "failed"
exception = retry_state.outcome.exception()
retry_d["exception"] = str(exception)
retry_d["exception_type"] = exception.__class__.__name__
else:
retry_d["outcome"] = "success"
retry_d["result"] = str(retry_state.outcome.result())
llm_run.events.append(
{
"name": "retry",
"time": datetime.now(timezone.utc),
"kwargs": retry_d,
},
llm_run = self._llm_run_with_retry_event(
retry_state=retry_state,
run_id=run_id,
)
return llm_run
@ -295,17 +143,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""End a trace for an LLM run."""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]
if "message" in output_generation:
output_generation["message"] = dumpd(
cast(ChatGeneration, generation).message
)
llm_run.end_time = datetime.now(timezone.utc)
llm_run.events.append({"name": "end", "time": llm_run.end_time})
llm_run = self._complete_llm_run(
response=response,
run_id=run_id,
)
self._end_trace(llm_run)
self._on_llm_end(llm_run)
return llm_run
@ -320,10 +161,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Handle an error for an LLM run."""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.error = self._get_stacktrace(error)
llm_run.end_time = datetime.now(timezone.utc)
llm_run.events.append({"name": "error", "time": llm_run.end_time})
llm_run = self._errored_llm_run(
error=error,
run_id=run_id,
)
self._end_trace(llm_run)
self._on_llm_error(llm_run)
return llm_run
@ -342,48 +183,21 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Start a trace for a chain run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
chain_run = Run(
id=run_id,
parent_run_id=parent_run_id,
chain_run = self._create_chain_run(
serialized=serialized,
inputs=self._get_chain_inputs(inputs),
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
child_runs=[],
run_type=run_type or "chain",
name=name, # type: ignore[arg-type]
tags=tags or [],
inputs=inputs,
run_id=run_id,
tags=tags,
parent_run_id=parent_run_id,
metadata=metadata,
run_type=run_type,
name=name,
**kwargs,
)
self._start_trace(chain_run)
self._on_chain_start(chain_run)
return chain_run
def _get_chain_inputs(self, inputs: Any) -> Any:
"""Get the inputs for a chain run."""
if self._schema_format in ("original", "original+chat"):
return inputs if isinstance(inputs, dict) else {"input": inputs}
elif self._schema_format == "streaming_events":
return {
"input": inputs,
}
else:
raise ValueError(f"Invalid format: {self._schema_format}")
def _get_chain_outputs(self, outputs: Any) -> Any:
"""Get the outputs for a chain run."""
if self._schema_format in ("original", "original+chat"):
return outputs if isinstance(outputs, dict) else {"output": outputs}
elif self._schema_format == "streaming_events":
return {
"output": outputs,
}
else:
raise ValueError(f"Invalid format: {self._schema_format}")
def on_chain_end(
self,
outputs: Dict[str, Any],
@ -393,12 +207,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""End a trace for a chain run."""
chain_run = self._get_run(run_id)
chain_run.outputs = self._get_chain_outputs(outputs)
chain_run.end_time = datetime.now(timezone.utc)
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = self._get_chain_inputs(inputs)
chain_run = self._complete_chain_run(
outputs=outputs,
run_id=run_id,
inputs=inputs,
**kwargs,
)
self._end_trace(chain_run)
self._on_chain_end(chain_run)
return chain_run
@ -412,12 +226,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Handle an error for a chain run."""
chain_run = self._get_run(run_id)
chain_run.error = self._get_stacktrace(error)
chain_run.end_time = datetime.now(timezone.utc)
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = self._get_chain_inputs(inputs)
chain_run = self._errored_chain_run(
error=error,
run_id=run_id,
inputs=inputs,
**kwargs,
)
self._end_trace(chain_run)
self._on_chain_error(chain_run)
return chain_run
@ -436,30 +250,16 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Start a trace for a tool run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
if self._schema_format in ("original", "original+chat"):
inputs = {"input": input_str}
elif self._schema_format == "streaming_events":
inputs = {"input": inputs}
else:
raise AssertionError(f"Invalid format: {self._schema_format}")
tool_run = Run(
id=run_id,
parent_run_id=parent_run_id,
tool_run = self._create_tool_run(
serialized=serialized,
# Wrapping in dict since Run requires a dict object.
input_str=input_str,
run_id=run_id,
tags=tags,
parent_run_id=parent_run_id,
metadata=metadata,
name=name,
inputs=inputs,
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
child_runs=[],
run_type="tool",
tags=tags or [],
name=name, # type: ignore[arg-type]
**kwargs,
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)
@ -467,10 +267,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.now(timezone.utc)
tool_run.events.append({"name": "end", "time": tool_run.end_time})
tool_run = self._complete_tool_run(
output=output,
run_id=run_id,
**kwargs,
)
self._end_trace(tool_run)
self._on_tool_end(tool_run)
return tool_run
@ -483,10 +284,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Handle an error for a tool run."""
tool_run = self._get_run(run_id, run_type="tool")
tool_run.error = self._get_stacktrace(error)
tool_run.end_time = datetime.now(timezone.utc)
tool_run.events.append({"name": "error", "time": tool_run.end_time})
tool_run = self._errored_tool_run(
error=error,
run_id=run_id,
)
self._end_trace(tool_run)
self._on_tool_error(tool_run)
return tool_run
@ -504,21 +305,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Run when Retriever starts running."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
retrieval_run = Run(
id=run_id,
name=name or "Retriever",
parent_run_id=parent_run_id,
retrieval_run = self._create_retrieval_run(
serialized=serialized,
inputs={"query": query},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
query=query,
run_id=run_id,
parent_run_id=parent_run_id,
tags=tags,
child_runs=[],
run_type="retriever",
metadata=metadata,
name=name,
**kwargs,
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
@ -532,10 +327,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Run when Retriever errors."""
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = self._get_stacktrace(error)
retrieval_run.end_time = datetime.now(timezone.utc)
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
retrieval_run = self._errored_retrieval_run(
error=error,
run_id=run_id,
**kwargs,
)
self._end_trace(retrieval_run)
self._on_retriever_error(retrieval_run)
return retrieval_run
@ -544,10 +340,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
"""Run when Retriever ends running."""
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.now(timezone.utc)
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
retrieval_run = self._complete_retrieval_run(
documents=documents,
run_id=run_id,
**kwargs,
)
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
return retrieval_run
@ -560,16 +357,349 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Copy the tracer."""
return self
def _on_run_create(self, run: Run) -> None:
class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
"""Async Base interface for tracers."""
@abstractmethod
async def _persist_run(self, run: Run) -> None:
    """Persist a run. Awaited by ``_end_trace`` for root runs (no parent)."""
async def _start_trace(self, run: Run) -> None:
    """Start a trace for a run.

    Indexes the run synchronously via the shared ``super()._start_trace``
    logic, then awaits the ``_on_run_create`` hook.

    Starting a trace will run concurrently with each _on_[run_type]_start
    method. No _on_[run_type]_start callback should depend on operations
    in _start_trace.
    """
    super()._start_trace(run)
    await self._on_run_create(run)
async def _end_trace(self, run: Run) -> None:
    """End a trace for a run.

    Persists the run if it is a root run (no parent), removes it from the
    in-memory index, then awaits the ``_on_run_update`` hook.

    Ending a trace will run concurrently with each _on_[run_type]_end method.
    No _on_[run_type]_end callback should depend on operations in _end_trace.
    """
    if not run.parent_run_id:
        await self._persist_run(run)
    self.run_map.pop(str(run.id))
    await self._on_run_update(run)
async def on_chat_model_start(
    self,
    serialized: Dict[str, Any],
    messages: List[List[BaseMessage]],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    name: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    """Create a run for a chat model invocation and fire start callbacks."""
    run = self._create_chat_model_run(
        serialized=serialized,
        messages=messages,
        run_id=run_id,
        parent_run_id=parent_run_id,
        tags=tags,
        metadata=metadata,
        name=name,
        **kwargs,
    )
    # Tracing and the subclass hook are independent, so run them concurrently.
    await asyncio.gather(
        self._start_trace(run),
        self._on_chat_model_start(run),
    )
    return run
async def on_llm_start(
    self,
    serialized: Dict[str, Any],
    prompts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Create an LLM run for ``prompts`` and fire start callbacks concurrently."""
    run = self._create_llm_run(
        serialized=serialized,
        prompts=prompts,
        run_id=run_id,
        parent_run_id=parent_run_id,
        tags=tags,
        metadata=metadata,
        **kwargs,
    )
    await asyncio.gather(self._start_trace(run), self._on_llm_start(run))
async def on_llm_new_token(
    self,
    token: str,
    *,
    chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Append a new-token event to the LLM run, then await the token hook."""
    run = self._llm_run_with_token_event(
        token=token,
        chunk=chunk,
        run_id=run_id,
        parent_run_id=parent_run_id,
        **kwargs,
    )
    await self._on_llm_new_token(run, token, chunk)
async def on_retry(
    self,
    retry_state: RetryCallState,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Record a retry event on the run identified by ``run_id``.

    NOTE(review): unlike the sync tracer's ``on_retry``, the updated run is
    discarded rather than returned — the event is only appended in-place.
    """
    self._llm_run_with_retry_event(
        retry_state=retry_state,
        run_id=run_id,
    )
async def on_llm_end(
    self,
    response: LLMResult,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Complete an LLM run with its response and fire end callbacks."""
    run = self._complete_llm_run(response=response, run_id=run_id)
    # Subclass hook and trace finalization do not depend on each other.
    await asyncio.gather(self._on_llm_end(run), self._end_trace(run))
async def on_llm_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Mark an LLM run as errored and fire error callbacks."""
    run = self._errored_llm_run(error=error, run_id=run_id)
    await asyncio.gather(self._on_llm_error(run), self._end_trace(run))
async def on_chain_start(
    self,
    serialized: Dict[str, Any],
    inputs: Dict[str, Any],
    *,
    run_id: UUID,
    tags: Optional[List[str]] = None,
    parent_run_id: Optional[UUID] = None,
    metadata: Optional[Dict[str, Any]] = None,
    run_type: Optional[str] = None,
    name: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Create a chain run and fire start callbacks concurrently."""
    run = self._create_chain_run(
        serialized=serialized,
        inputs=inputs,
        run_id=run_id,
        tags=tags,
        parent_run_id=parent_run_id,
        metadata=metadata,
        run_type=run_type,
        name=name,
        **kwargs,
    )
    await asyncio.gather(self._start_trace(run), self._on_chain_start(run))
async def on_chain_end(
    self,
    outputs: Dict[str, Any],
    *,
    run_id: UUID,
    inputs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Complete a chain run with its outputs and fire end callbacks."""
    run = self._complete_chain_run(
        outputs=outputs,
        run_id=run_id,
        inputs=inputs,
        **kwargs,
    )
    await asyncio.gather(self._end_trace(run), self._on_chain_end(run))
async def on_chain_error(
    self,
    error: BaseException,
    *,
    inputs: Optional[Dict[str, Any]] = None,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Mark a chain run as errored and fire error callbacks."""
    run = self._errored_chain_run(
        error=error,
        inputs=inputs,
        run_id=run_id,
        **kwargs,
    )
    await asyncio.gather(self._end_trace(run), self._on_chain_error(run))
async def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    *,
    run_id: UUID,
    tags: Optional[List[str]] = None,
    parent_run_id: Optional[UUID] = None,
    metadata: Optional[Dict[str, Any]] = None,
    name: Optional[str] = None,
    inputs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Create a tool run and fire start callbacks concurrently.

    Args:
        serialized: Serialized representation of the tool.
        input_str: The tool's raw string input.
        run_id: Unique id for this run.
        tags: Optional tags for the run.
        parent_run_id: Id of the parent run, if any.
        metadata: Optional metadata for the run.
        name: Optional display name for the run.
        inputs: Optional structured inputs for the run.
    """
    tool_run = self._create_tool_run(
        serialized=serialized,
        input_str=input_str,
        run_id=run_id,
        tags=tags,
        parent_run_id=parent_run_id,
        metadata=metadata,
        # Fix: forward `name` — it was accepted but silently dropped,
        # unlike the sync tracer's on_tool_start, so tool runs lost
        # their caller-supplied name.
        name=name,
        inputs=inputs,
        **kwargs,
    )
    tasks = [self._start_trace(tool_run), self._on_tool_start(tool_run)]
    await asyncio.gather(*tasks)
async def on_tool_end(
    self,
    output: Any,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Complete a tool run with its output and fire end callbacks."""
    run = self._complete_tool_run(
        output=output,
        run_id=run_id,
        **kwargs,
    )
    await asyncio.gather(self._end_trace(run), self._on_tool_end(run))
async def on_tool_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Mark a tool run as errored and fire error callbacks."""
    run = self._errored_tool_run(error=error, run_id=run_id)
    await asyncio.gather(self._end_trace(run), self._on_tool_error(run))
async def on_retriever_start(
    self,
    serialized: Dict[str, Any],
    query: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    name: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Create a retriever run and fire start callbacks concurrently.

    Args:
        serialized: Serialized representation of the retriever.
        query: The retrieval query.
        run_id: Unique id for this run.
        parent_run_id: Id of the parent run, if any.
        tags: Optional tags for the run.
        metadata: Optional metadata for the run.
        name: Optional display name for the run.
    """
    retriever_run = self._create_retrieval_run(
        serialized=serialized,
        query=query,
        run_id=run_id,
        parent_run_id=parent_run_id,
        tags=tags,
        metadata=metadata,
        name=name,
        # Fix: forward extra kwargs — they were accepted but dropped,
        # unlike the sync tracer's on_retriever_start which passes them
        # through to _create_retrieval_run.
        **kwargs,
    )
    tasks = [
        self._start_trace(retriever_run),
        self._on_retriever_start(retriever_run),
    ]
    await asyncio.gather(*tasks)
async def on_retriever_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Mark a retriever run as errored and fire error callbacks."""
    run = self._errored_retrieval_run(
        error=error,
        run_id=run_id,
        **kwargs,
    )
    await asyncio.gather(
        self._end_trace(run),
        self._on_retriever_error(run),
    )
async def on_retriever_end(
    self,
    documents: Sequence[Document],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Complete a retriever run with its documents and fire end callbacks."""
    run = self._complete_retrieval_run(
        documents=documents,
        run_id=run_id,
        **kwargs,
    )
    await asyncio.gather(self._end_trace(run), self._on_retriever_end(run))
async def _on_run_create(self, run: Run) -> None:
    """Process a run upon creation. No-op by default; subclasses may override."""
    pass
def _on_run_update(self, run: Run) -> None:
async def _on_run_update(self, run: Run) -> None:
"""Process a run upon update."""
def _on_llm_start(self, run: Run) -> None:
async def _on_llm_start(self, run: Run) -> None:
"""Process the LLM Run upon start."""
def _on_llm_new_token(
async def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""
async def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error."""
async def _on_llm_new_token(
self,
run: Run,
token: str,
@ -577,38 +707,32 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> None:
"""Process new LLM token."""
def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""
def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error."""
def _on_chain_start(self, run: Run) -> None:
async def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start."""
def _on_chain_end(self, run: Run) -> None:
async def _on_chain_end(self, run: Run) -> None:
"""Process the Chain Run."""
def _on_chain_error(self, run: Run) -> None:
async def _on_chain_error(self, run: Run) -> None:
"""Process the Chain Run upon error."""
def _on_tool_start(self, run: Run) -> None:
async def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start."""
def _on_tool_end(self, run: Run) -> None:
async def _on_tool_end(self, run: Run) -> None:
"""Process the Tool Run."""
def _on_tool_error(self, run: Run) -> None:
async def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
def _on_chat_model_start(self, run: Run) -> None:
async def _on_chat_model_start(self, run: Run) -> None:
"""Process the Chat Model Run upon start."""
def _on_retriever_start(self, run: Run) -> None:
async def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
def _on_retriever_end(self, run: Run) -> None:
async def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
def _on_retriever_error(self, run: Run) -> None:
async def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""

@ -0,0 +1,566 @@
"""Utilities for the root listener."""
from __future__ import annotations
import logging
import sys
import traceback
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from uuid import UUID
from tenacity import RetryCallState
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.messages import BaseMessage
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from langchain_core.documents import Document
logger = logging.getLogger(__name__)
SCHEMA_FORMAT_TYPE = Literal["original", "streaming_events"]
class _TracerCore(ABC):
"""
Abstract base class for tracers
This class provides common methods, and reusable methods for tracers.
"""
def __init__(
    self,
    *,
    _schema_format: Literal[
        "original", "streaming_events", "original+chat"
    ] = "original",
    **kwargs: Any,
) -> None:
    """Initialize the tracer.

    Args:
        _schema_format: Primarily changes how the inputs and outputs are
            handled. For internal use only. This API will change.
            - 'original' is the format used by all current tracers.
              This format is slightly inconsistent with respect to inputs
              and outputs.
            - 'streaming_events' is used for supporting streaming events,
              for internal usage. It will likely change in the future, or
              be deprecated entirely in favor of a dedicated async tracer
              for streaming events.
            - 'original+chat' is a format that is the same as 'original'
              except it does NOT raise an attribute error on_chat_model_start
        kwargs: Additional keyword arguments that will be passed to
            the super class.
    """
    super().__init__(**kwargs)
    self._schema_format = _schema_format  # For internal use only API will change.
    self.run_map: Dict[str, Run] = {}
    """Map of run ID to run. Cleared on run end."""
    self.order_map: Dict[UUID, Tuple[UUID, str]] = {}
    """Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
@abstractmethod
def _persist_run(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Persist a run."""
@staticmethod
def _add_child_run(
parent_run: Run,
child_run: Run,
) -> None:
"""Add child run to a chain run or tool run."""
parent_run.child_runs.append(child_run)
@staticmethod
def _get_stacktrace(error: BaseException) -> str:
"""Get the stacktrace of the parent error."""
msg = repr(error)
try:
if sys.version_info < (3, 10):
tb = traceback.format_exception(
error.__class__, error, error.__traceback__
)
else:
tb = traceback.format_exception(error)
return (msg + "\n\n".join(tb)).strip()
except: # noqa: E722
return msg
def _start_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # type: ignore[return]
current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id)
if run.parent_run_id:
if parent := self.order_map.get(run.parent_run_id):
run.trace_id, run.dotted_order = parent
run.dotted_order += "." + current_dotted_order
if parent_run := self.run_map.get(str(run.parent_run_id)):
self._add_child_run(parent_run, run)
else:
logger.warning(
f"Parent run {run.parent_run_id} not found for run {run.id}."
" Treating as a root run."
)
run.parent_run_id = None
run.trace_id = run.id
run.dotted_order = current_dotted_order
else:
run.trace_id = run.id
run.dotted_order = current_dotted_order
self.order_map[run.id] = (run.trace_id, run.dotted_order)
self.run_map[str(run.id)] = run
def _get_run(
self, run_id: UUID, run_type: Union[str, Set[str], None] = None
) -> Run:
try:
run = self.run_map[str(run_id)]
except KeyError as exc:
raise TracerException(f"No indexed run ID {run_id}.") from exc
if isinstance(run_type, str):
run_types: Union[Set[str], None] = {run_type}
else:
run_types = run_type
if run_types is not None and run.run_type not in run_types:
raise TracerException(
f"Found {run.run_type} run at ID {run_id}, "
f"but expected {run_types} run."
)
return run
def _create_chat_model_run(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a chat model run."""
if self._schema_format not in ("streaming_events", "original+chat"):
# Please keep this un-implemented for backwards compatibility.
# When it's unimplemented old tracers that use the "original" format
# fallback on the on_llm_start method implementation if they
# find that the on_chat_model_start method is not implemented.
# This can eventually be cleaned up by writing a "modern" tracer
# that has all the updated schema changes corresponding to
# the "streaming_events" format.
raise NotImplementedError(
f"Chat model tracing is not supported in "
f"for {self._schema_format} format."
)
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
return Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
# WARNING: This is valid ONLY for streaming_events.
# run_type="llm" is what's used by virtually all tracers.
# Changing this to "chat_model" may break triggering on_llm_start
run_type="chat_model",
tags=tags,
name=name, # type: ignore[arg-type]
)
def _create_llm_run(
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a llm run"""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
return Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
# TODO: Figure out how to expose kwargs here
inputs={"prompts": prompts},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
run_type="llm",
tags=tags or [],
name=name, # type: ignore[arg-type]
)
def _llm_run_with_token_event(
self,
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""
Append token event to LLM run and return the run
"""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
llm_run.events.append(
{
"name": "new_token",
"time": datetime.now(timezone.utc),
"kwargs": event_kwargs,
},
)
return llm_run
def _llm_run_with_retry_event(
self,
retry_state: RetryCallState,
run_id: UUID,
**kwargs: Any,
) -> Run:
llm_run = self._get_run(run_id)
retry_d: Dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
}
if retry_state.outcome is None:
retry_d["outcome"] = "N/A"
elif retry_state.outcome.failed:
retry_d["outcome"] = "failed"
exception = retry_state.outcome.exception()
retry_d["exception"] = str(exception)
retry_d["exception_type"] = exception.__class__.__name__
else:
retry_d["outcome"] = "success"
retry_d["result"] = str(retry_state.outcome.result())
llm_run.events.append(
{
"name": "retry",
"time": datetime.now(timezone.utc),
"kwargs": retry_d,
},
)
return llm_run
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]
if "message" in output_generation:
output_generation["message"] = dumpd(
cast(ChatGeneration, generation).message
)
llm_run.end_time = datetime.now(timezone.utc)
llm_run.events.append({"name": "end", "time": llm_run.end_time})
return llm_run
def _errored_llm_run(self, error: BaseException, run_id: UUID) -> Run:
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.error = self._get_stacktrace(error)
llm_run.end_time = datetime.now(timezone.utc)
llm_run.events.append({"name": "error", "time": llm_run.end_time})
return llm_run
def _create_chain_run(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a chain Run"""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
return Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs=self._get_chain_inputs(inputs),
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
child_runs=[],
run_type=run_type or "chain",
name=name, # type: ignore[arg-type]
tags=tags or [],
)
def _get_chain_inputs(self, inputs: Any) -> Any:
"""Get the inputs for a chain run."""
if self._schema_format in ("original", "original+chat"):
return inputs if isinstance(inputs, dict) else {"input": inputs}
elif self._schema_format == "streaming_events":
return {
"input": inputs,
}
else:
raise ValueError(f"Invalid format: {self._schema_format}")
def _get_chain_outputs(self, outputs: Any) -> Any:
"""Get the outputs for a chain run."""
if self._schema_format in ("original", "original+chat"):
return outputs if isinstance(outputs, dict) else {"output": outputs}
elif self._schema_format == "streaming_events":
return {
"output": outputs,
}
else:
raise ValueError(f"Invalid format: {self._schema_format}")
def _complete_chain_run(
self,
outputs: Dict[str, Any],
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Update a chain run with outputs and end time."""
chain_run = self._get_run(run_id)
chain_run.outputs = self._get_chain_outputs(outputs)
chain_run.end_time = datetime.now(timezone.utc)
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = self._get_chain_inputs(inputs)
return chain_run
def _errored_chain_run(
self,
error: BaseException,
inputs: Optional[Dict[str, Any]],
run_id: UUID,
**kwargs: Any,
) -> Run:
chain_run = self._get_run(run_id)
chain_run.error = self._get_stacktrace(error)
chain_run.end_time = datetime.now(timezone.utc)
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = self._get_chain_inputs(inputs)
return chain_run
def _create_tool_run(
self,
serialized: Dict[str, Any],
input_str: str,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Create a tool run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
if self._schema_format in ("original", "original+chat"):
inputs = {"input": input_str}
elif self._schema_format == "streaming_events":
inputs = {"input": inputs}
else:
raise AssertionError(f"Invalid format: {self._schema_format}")
return Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
# Wrapping in dict since Run requires a dict object.
inputs=inputs,
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
child_runs=[],
run_type="tool",
tags=tags or [],
name=name, # type: ignore[arg-type]
)
def _complete_tool_run(
self,
output: Dict[str, Any],
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Update a tool run with outputs and end time."""
tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.now(timezone.utc)
tool_run.events.append({"name": "end", "time": tool_run.end_time})
return tool_run
def _errored_tool_run(
self,
error: BaseException,
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Update a tool run with error and end time."""
tool_run = self._get_run(run_id, run_type="tool")
tool_run.error = self._get_stacktrace(error)
tool_run.end_time = datetime.now(timezone.utc)
tool_run.events.append({"name": "error", "time": tool_run.end_time})
return tool_run
def _create_retrieval_run(
self,
serialized: Dict[str, Any],
query: str,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a retrieval run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
return Run(
id=run_id,
name=name or "Retriever",
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"query": query},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
tags=tags,
child_runs=[],
run_type="retriever",
)
def _complete_retrieval_run(
self,
documents: Sequence[Document],
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Update a retrieval run with outputs and end time."""
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.now(timezone.utc)
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
return retrieval_run
def _errored_retrieval_run(
self,
error: BaseException,
run_id: UUID,
**kwargs: Any,
) -> Run:
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = self._get_stacktrace(error)
retrieval_run.end_time = datetime.now(timezone.utc)
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
return retrieval_run
def __deepcopy__(self, memo: dict) -> _TracerCore:
"""Deepcopy the tracer."""
return self
def __copy__(self) -> _TracerCore:
"""Copy the tracer."""
return self
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""End a trace for a run."""
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process a run upon creation."""
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process a run upon update."""
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the LLM Run upon start."""
def _on_llm_new_token(
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> Union[None, Coroutine[Any, Any, None]]:
"""Process new LLM token."""
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the LLM Run."""
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the LLM Run upon error."""
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chain Run upon start."""
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chain Run."""
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chain Run upon error."""
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Tool Run upon start."""
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Tool Run."""
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Tool Run upon error."""
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Chat Model Run upon start."""
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Retriever Run upon start."""
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Retriever Run."""
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
"""Process the Retriever Run upon error."""

@ -1,14 +1,18 @@
from typing import Callable, Optional, Union
from typing import Awaitable, Callable, Optional, Union
from uuid import UUID
from langchain_core.runnables.config import (
RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
)
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.base import AsyncBaseTracer, BaseTracer
from langchain_core.tracers.schemas import Run
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
AsyncListener = Union[
Callable[[Run], Awaitable[None]], Callable[[Run, RunnableConfig], Awaitable[None]]
]
class RootListenersTracer(BaseTracer):
@ -54,3 +58,50 @@ class RootListenersTracer(BaseTracer):
else:
if self._arg_on_error is not None:
call_func_with_variable_args(self._arg_on_error, run, self.config)
class AsyncRootListenersTracer(AsyncBaseTracer):
    """Async tracer that invokes user-supplied async listeners for the root
    run's start, end, and error events.

    Only the first run seen (the root of the trace) triggers the listeners;
    child runs are ignored.
    """
    def __init__(
        self,
        *,
        config: RunnableConfig,
        on_start: Optional[AsyncListener],
        on_end: Optional[AsyncListener],
        on_error: Optional[AsyncListener],
    ) -> None:
        """Store the runnable config and the optional listener callables."""
        super().__init__()
        self.config = config
        self._arg_on_start = on_start
        self._arg_on_end = on_end
        self._arg_on_error = on_error
        # ID of the first (root) run observed; set on first _on_run_create.
        self.root_id: Optional[UUID] = None
    async def _persist_run(self, run: Run) -> None:
        """No-op: persistence is a legacy hook called once per run tree."""
        pass
    async def _on_run_create(self, run: Run) -> None:
        """Capture the root run and fire the on_start listener once."""
        if self.root_id is None:
            self.root_id = run.id
            if self._arg_on_start is not None:
                await acall_func_with_variable_args(
                    self._arg_on_start, run, self.config
                )
    async def _on_run_update(self, run: Run) -> None:
        """Fire on_end or on_error when the root run finishes."""
        if run.id != self.root_id:
            return
        listener = self._arg_on_end if run.error is None else self._arg_on_error
        if listener is not None:
            await acall_func_with_variable_args(listener, run, self.config)

@ -0,0 +1,598 @@
"""Test Tracer classes."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, List
from uuid import uuid4
import pytest
from freezegun import freeze_time
from langchain_core.callbacks import AsyncCallbackManager
from langchain_core.exceptions import TracerException
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
from langchain_core.tracers.base import AsyncBaseTracer
from langchain_core.tracers.schemas import Run
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
class FakeAsyncTracer(AsyncBaseTracer):
    """Fake tracer to test async based tracers.

    Collects every persisted run in ``self.runs`` so tests can assert on the
    exact Run objects the tracer produced.
    """
    def __init__(self) -> None:
        """Initialize the tracer."""
        super().__init__()
        # Runs are appended in the order they are persisted (i.e. completed).
        self.runs: List[Run] = []
    async def _persist_run(self, run: Run) -> None:
        self.runs.append(run)
def _compare_run_with_error(run: Any, expected_run: Any) -> None:
    """Recursively assert that ``run`` equals ``expected_run``.

    All fields are compared for strict equality except ``error``, which is
    compared by substring containment (the recorded error includes a full
    stacktrace, the expected one only the exception repr).
    """
    if run.child_runs:
        assert len(expected_run.child_runs) == len(run.child_runs)
        for got_child, want_child in zip(run.child_runs, expected_run.child_runs):
            _compare_run_with_error(got_child, want_child)
    got = run.dict(exclude={"child_runs"})
    got_err = got.pop("error")
    want = expected_run.dict(exclude={"child_runs"})
    want_err = want.pop("error")
    assert got == want
    if want_err is None:
        assert got_err is None
    else:
        assert got_err is not None
        assert want_err in got_err
@freeze_time("2023-01-01")
async def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run."""
    uuid = uuid4()
    # Expected Run: time is frozen, so start/end/events all share one instant.
    compare_run = Run(  # type: ignore[call-arg]
        id=uuid,
        parent_run_id=None,
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "end", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized=SERIALIZED,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
        error=None,
        run_type="llm",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    # Drive the tracer through a complete start -> end LLM lifecycle.
    await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
async def test_tracer_chat_model_run() -> None:
    """Test tracer on a Chat Model run."""
    tracer = FakeAsyncTracer()
    # Route the start through the async callback manager, which generates the
    # run_id used in the expected Run below.
    manager = AsyncCallbackManager(handlers=[tracer])
    run_managers = await manager.on_chat_model_start(
        serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
    )
    compare_run = Run(
        id=str(run_managers[0].run_id),  # type: ignore[arg-type]
        name="chat_model",
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "end", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized=SERIALIZED_CHAT,
        inputs=dict(prompts=["Human: "]),
        outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
        error=None,
        run_type="llm",
        trace_id=run_managers[0].run_id,
        dotted_order=f"20230101T000000000000Z{run_managers[0].run_id}",
    )
    for run_manager in run_managers:
        await run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
async def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start."""
    tracer = FakeAsyncTracer()
    # Ending a run that was never started must raise, not silently succeed.
    with pytest.raises(TracerException):
        await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
@freeze_time("2023-01-01")
async def test_tracer_multiple_llm_runs() -> None:
    """Test the tracer with multiple runs."""
    uuid = uuid4()
    compare_run = Run(
        id=uuid,
        name="llm",
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "end", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
        error=None,
        run_type="llm",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    num_runs = 10
    # Re-using the same run_id is fine: each start/end pair persists a fresh
    # Run and the index is cleared on end.
    for _ in range(num_runs):
        await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
        await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run] * num_runs
@freeze_time("2023-01-01")
async def test_tracer_chain_run() -> None:
    """Test tracer on a Chain run."""
    uuid = uuid4()
    compare_run = Run(  # type: ignore[call-arg]
        id=str(uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "end", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        error=None,
        run_type="chain",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    # Complete start -> end chain lifecycle with empty inputs/outputs.
    await tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    await tracer.on_chain_end(outputs={}, run_id=uuid)
    assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
async def test_tracer_tool_run() -> None:
    """Test tracer on a Tool run."""
    uuid = uuid4()
    compare_run = Run(  # type: ignore[call-arg]
        id=str(uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "end", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized={"name": "tool"},
        inputs={"input": "test"},
        outputs={"output": "test"},
        error=None,
        run_type="tool",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    # Complete start -> end tool lifecycle; input_str is wrapped as {"input": ...}.
    await tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=uuid
    )
    await tracer.on_tool_end("test", run_id=uuid)
    assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
async def test_tracer_nested_run() -> None:
    """Test tracer on a nested run."""
    tracer = FakeAsyncTracer()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    # Build the tree chain -> (tool -> llm1, llm2) ten times; only the root
    # chain run is persisted per iteration, carrying the whole child tree.
    for _ in range(10):
        await tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        await tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        await tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=tool_uuid,
        )
        await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        await tracer.on_tool_end("test", run_id=tool_uuid)
        await tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        await tracer.on_chain_end(outputs={}, run_id=chain_uuid)
    # Expected tree mirrors the callback sequence above; dotted_order encodes
    # the full ancestry path of each run.
    compare_run = Run(  # type: ignore[call-arg]
        id=str(chain_uuid),  # type: ignore[arg-type]
        error=None,
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "end", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        run_type="chain",
        trace_id=chain_uuid,
        dotted_order=f"20230101T000000000000Z{chain_uuid}",
        child_runs=[
            Run(  # type: ignore[call-arg]
                id=tool_uuid,
                parent_run_id=chain_uuid,
                start_time=datetime.now(timezone.utc),
                end_time=datetime.now(timezone.utc),
                events=[
                    {"name": "start", "time": datetime.now(timezone.utc)},
                    {"name": "end", "time": datetime.now(timezone.utc)},
                ],
                extra={},
                serialized={"name": "tool"},
                inputs=dict(input="test"),
                outputs=dict(output="test"),
                error=None,
                run_type="tool",
                trace_id=chain_uuid,
                dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
                child_runs=[
                    Run(  # type: ignore[call-arg]
                        id=str(llm_uuid1),  # type: ignore[arg-type]
                        parent_run_id=str(tool_uuid),  # type: ignore[arg-type]
                        error=None,
                        start_time=datetime.now(timezone.utc),
                        end_time=datetime.now(timezone.utc),
                        events=[
                            {"name": "start", "time": datetime.now(timezone.utc)},
                            {"name": "end", "time": datetime.now(timezone.utc)},
                        ],
                        extra={},
                        serialized=SERIALIZED,
                        inputs=dict(prompts=[]),
                        outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
                        run_type="llm",
                        trace_id=chain_uuid,
                        dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}.20230101T000000000000Z{llm_uuid1}",
                    )
                ],
            ),
            Run(  # type: ignore[call-arg]
                id=str(llm_uuid2),  # type: ignore[arg-type]
                parent_run_id=str(chain_uuid),  # type: ignore[arg-type]
                error=None,
                start_time=datetime.now(timezone.utc),
                end_time=datetime.now(timezone.utc),
                events=[
                    {"name": "start", "time": datetime.now(timezone.utc)},
                    {"name": "end", "time": datetime.now(timezone.utc)},
                ],
                extra={},
                serialized=SERIALIZED,
                inputs=dict(prompts=[]),
                outputs=LLMResult(generations=[[]]),  # type: ignore[arg-type]
                run_type="llm",
                trace_id=chain_uuid,
                dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
            ),
        ],
    )
    assert tracer.runs[0] == compare_run
    assert tracer.runs == [compare_run] * 10
@freeze_time("2023-01-01")
async def test_tracer_llm_run_on_error() -> None:
    """Test tracer on an LLM run with an error."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = Run(  # type: ignore[call-arg]
        id=str(uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "error", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=None,
        error=repr(exception),
        run_type="llm",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    await tracer.on_llm_error(exception, run_id=uuid)
    assert len(tracer.runs) == 1
    # error is compared by substring: the traced error includes a stacktrace.
    _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
async def test_tracer_llm_run_on_error_callback() -> None:
    """Test tracer on an LLM run with an error and a callback."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = Run(  # type: ignore[call-arg]
        id=str(uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "error", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=None,
        error=repr(exception),
        run_type="llm",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    # Subclass hooks the _on_llm_error callback to capture the errored run.
    class FakeTracerWithLlmErrorCallback(FakeAsyncTracer):
        error_run = None
        async def _on_llm_error(self, run: Run) -> None:
            self.error_run = run
    tracer = FakeTracerWithLlmErrorCallback()
    await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    await tracer.on_llm_error(exception, run_id=uuid)
    _compare_run_with_error(tracer.error_run, compare_run)
@freeze_time("2023-01-01")
async def test_tracer_chain_run_on_error() -> None:
    """Test tracer on a Chain run with an error."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = Run(  # type: ignore[call-arg]
        id=str(uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "error", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized={"name": "chain"},
        inputs={},
        outputs=None,
        error=repr(exception),
        run_type="chain",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    await tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    await tracer.on_chain_error(exception, run_id=uuid)
    # error is compared by substring: the traced error includes a stacktrace.
    _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
async def test_tracer_tool_run_on_error() -> None:
    """Test tracer on a Tool run with an error."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = Run(  # type: ignore[call-arg]
        id=str(uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "error", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized={"name": "tool"},
        inputs=dict(input="test"),
        outputs=None,
        action="{'name': 'tool'}",
        error=repr(exception),
        run_type="tool",
        trace_id=uuid,
        dotted_order=f"20230101T000000000000Z{uuid}",
    )
    tracer = FakeAsyncTracer()
    await tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=uuid
    )
    await tracer.on_tool_error(exception, run_id=uuid)
    # error is compared by substring: the traced error includes a stacktrace.
    _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
async def test_tracer_nested_runs_on_error() -> None:
    """Test tracer on a nested run with an error."""
    exception = Exception("test")
    tracer = FakeAsyncTracer()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    llm_uuid3 = uuid4()
    # Build chain -> (llm1 ok, llm2 ok, tool -> llm3) three times, with the
    # innermost llm3 erroring and the error propagating to tool and chain.
    for _ in range(3):
        await tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        await tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=chain_uuid,
        )
        await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        await tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        await tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        await tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid3,
            parent_run_id=tool_uuid,
        )
        await tracer.on_llm_error(exception, run_id=llm_uuid3)
        await tracer.on_tool_error(exception, run_id=tool_uuid)
        await tracer.on_chain_error(exception, run_id=chain_uuid)
    # Expected tree: successful llm1/llm2 have end events; llm3/tool/chain
    # carry the error and an "error" event instead of "end".
    compare_run = Run(  # type: ignore[call-arg]
        id=str(chain_uuid),  # type: ignore[arg-type]
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        events=[
            {"name": "start", "time": datetime.now(timezone.utc)},
            {"name": "error", "time": datetime.now(timezone.utc)},
        ],
        extra={},
        serialized={"name": "chain"},
        error=repr(exception),
        inputs={},
        outputs=None,
        run_type="chain",
        trace_id=chain_uuid,
        dotted_order=f"20230101T000000000000Z{chain_uuid}",
        child_runs=[
            Run(  # type: ignore[call-arg]
                id=str(llm_uuid1),  # type: ignore[arg-type]
                parent_run_id=str(chain_uuid),  # type: ignore[arg-type]
                start_time=datetime.now(timezone.utc),
                end_time=datetime.now(timezone.utc),
                events=[
                    {"name": "start", "time": datetime.now(timezone.utc)},
                    {"name": "end", "time": datetime.now(timezone.utc)},
                ],
                extra={},
                serialized=SERIALIZED,
                error=None,
                inputs=dict(prompts=[]),
                outputs=LLMResult(generations=[[]], llm_output=None),  # type: ignore[arg-type]
                run_type="llm",
                trace_id=chain_uuid,
                dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid1}",
            ),
            Run(  # type: ignore[call-arg]
                id=str(llm_uuid2),  # type: ignore[arg-type]
                parent_run_id=str(chain_uuid),  # type: ignore[arg-type]
                start_time=datetime.now(timezone.utc),
                end_time=datetime.now(timezone.utc),
                events=[
                    {"name": "start", "time": datetime.now(timezone.utc)},
                    {"name": "end", "time": datetime.now(timezone.utc)},
                ],
                extra={},
                serialized=SERIALIZED,
                error=None,
                inputs=dict(prompts=[]),
                outputs=LLMResult(generations=[[]], llm_output=None),  # type: ignore[arg-type]
                run_type="llm",
                trace_id=chain_uuid,
                dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
            ),
            Run(  # type: ignore[call-arg]
                id=str(tool_uuid),  # type: ignore[arg-type]
                parent_run_id=str(chain_uuid),  # type: ignore[arg-type]
                start_time=datetime.now(timezone.utc),
                end_time=datetime.now(timezone.utc),
                events=[
                    {"name": "start", "time": datetime.now(timezone.utc)},
                    {"name": "error", "time": datetime.now(timezone.utc)},
                ],
                extra={},
                serialized={"name": "tool"},
                error=repr(exception),
                inputs=dict(input="test"),
                outputs=None,
                action="{'name': 'tool'}",
                trace_id=chain_uuid,
                dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
                child_runs=[
                    Run(  # type: ignore[call-arg]
                        id=str(llm_uuid3),  # type: ignore[arg-type]
                        parent_run_id=str(tool_uuid),  # type: ignore[arg-type]
                        start_time=datetime.now(timezone.utc),
                        end_time=datetime.now(timezone.utc),
                        events=[
                            {"name": "start", "time": datetime.now(timezone.utc)},
                            {"name": "error", "time": datetime.now(timezone.utc)},
                        ],
                        extra={},
                        serialized=SERIALIZED,
                        error=repr(exception),
                        inputs=dict(prompts=[]),
                        outputs=None,
                        run_type="llm",
                        trace_id=chain_uuid,
                        dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}.20230101T000000000000Z{llm_uuid3}",
                    )
                ],
                run_type="tool",
            ),
        ],
    )
    assert len(tracer.runs) == 3
    for run in tracer.runs:
        _compare_run_with_error(run, compare_run)

@ -13,10 +13,11 @@ from freezegun import freeze_time
from langsmith import Client, traceable
from langchain_core.callbacks import CallbackManager
from langchain_core.exceptions import TracerException
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
from langchain_core.runnables import chain as as_runnable
from langchain_core.tracers.base import BaseTracer, TracerException
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
SERIALIZED = {"id": ["llm"]}

Loading…
Cancel
Save