You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/core/langchain_core/tracers/root_listeners.py

57 lines
1.7 KiB
Python

from typing import Callable, Optional, Union
from uuid import UUID
from langchain_core.runnables.config import (
RunnableConfig,
call_func_with_variable_args,
)
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
# Type alias for a listener callback: a callable that returns None and accepts
# either just the Run, or the Run followed by the RunnableConfig.
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
class RootListenersTracer(BaseTracer):
    """Tracer that forwards root-run events to optional listener callbacks.

    The first run seen by ``_on_run_create`` is recorded as the root run.
    Runs created afterwards are ignored, and ``_on_run_update`` only reacts
    to the run whose id matches the recorded root id.
    """

    def __init__(
        self,
        *,
        config: RunnableConfig,
        on_start: Optional[Listener],
        on_end: Optional[Listener],
        on_error: Optional[Listener],
    ) -> None:
        """Store the config and the three optional listeners.

        Args:
            config: Passed along with the run whenever a listener is invoked.
            on_start: Invoked when the root run is created, or None to skip.
            on_end: Invoked when the root run is updated with no error set,
                or None to skip.
            on_error: Invoked when the root run is updated with an error set,
                or None to skip.
        """
        super().__init__()
        # No root has been observed yet; filled in by _on_run_create.
        self.root_id: Optional[UUID] = None
        self.config = config
        (
            self._arg_on_start,
            self._arg_on_end,
            self._arg_on_error,
        ) = (on_start, on_end, on_error)

    def _persist_run(self, run: Run) -> None:
        """Do nothing.

        This is a legacy method only called once for an entire run tree,
        therefore not useful here.
        """

    def _on_run_create(self, run: Run) -> None:
        """Record the first created run as the root and fire ``on_start``."""
        if self.root_id is None:
            # Only the very first run becomes the root; later ones fall through.
            self.root_id = run.id
            start_listener = self._arg_on_start
            if start_listener is not None:
                # NOTE(review): presumably the helper passes config only when
                # the listener accepts it -- confirm against its definition.
                call_func_with_variable_args(start_listener, run, self.config)

    def _on_run_update(self, run: Run) -> None:
        """Fire ``on_end`` or ``on_error`` when the root run is updated."""
        if run.id != self.root_id:
            # Updates for any run other than the recorded root are ignored.
            return

        # A run with no error goes to on_end; otherwise to on_error.
        listener = self._arg_on_end if run.error is None else self._arg_on_error
        if listener is not None:
            call_func_with_variable_args(listener, run, self.config)