forked from Archives/langchain
add tracing support to langchain (#741)
* add implementations of `BaseCallbackHandler` to support tracing: `SharedTracer` which is thread-safe and `Tracer` which is not and is meant to be used locally. * Tracers persist runs to locally running `langchain-server` Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>ankush/async-llmchain
parent
7f76a1189c
commit
57609845df
@ -0,0 +1,12 @@
|
||||
"""Tracers that record execution of LangChain runs."""
|
||||
|
||||
from langchain.callbacks.tracers.base import SharedTracer, Tracer
|
||||
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
|
||||
|
||||
|
||||
class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
|
||||
"""Shared tracer that records LangChain execution to LangChain endpoint."""
|
||||
|
||||
|
||||
class LangChainTracer(Tracer, BaseLangChainTracer):
|
||||
"""Tracer that records LangChain execution to LangChain endpoint."""
|
@ -0,0 +1,334 @@
|
||||
"""Base interfaces for tracing runs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.shared import Singleton
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class TracerException(Exception):
|
||||
"""Base class for exceptions in tracers module."""
|
||||
|
||||
|
||||
class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Base interface for tracers."""
|
||||
|
||||
@abstractmethod
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
|
||||
@abstractmethod
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
|
||||
@abstractmethod
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
|
||||
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
|
||||
"""NOT thread safe, do not call this method from multiple threads."""
|
||||
session_create = TracerSessionCreate(name=name, extra=kwargs)
|
||||
session = self._persist_session(session_create)
|
||||
self._session = session
|
||||
return session
|
||||
|
||||
@abstractmethod
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session and set it as the Tracer's session."""
|
||||
|
||||
@abstractmethod
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
|
||||
@_execution_order.setter
|
||||
@abstractmethod
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
|
||||
@_session.setter
|
||||
@abstractmethod
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
|
||||
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Start a trace for a run."""
|
||||
self._execution_order += 1
|
||||
|
||||
if self._stack:
|
||||
if not (
|
||||
isinstance(self._stack[-1], ChainRun)
|
||||
or isinstance(self._stack[-1], ToolRun)
|
||||
):
|
||||
raise TracerException(
|
||||
f"Nested {run.__class__.__name__} can only be"
|
||||
f" logged inside a ChainRun or ToolRun"
|
||||
)
|
||||
self._add_child_run(self._stack[-1], run)
|
||||
self._stack.append(run)
|
||||
|
||||
def _end_trace(self) -> None:
|
||||
"""End a trace for a run."""
|
||||
run = self._stack.pop()
|
||||
if not self._stack:
|
||||
self._execution_order = 1
|
||||
self._persist_run(run)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
|
||||
llm_run = LLMRun(
|
||||
serialized=serialized,
|
||||
prompts=prompts,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""End a trace for an LLM run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].response = response
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Handle an error for an LLM run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
self._stack[-1].error = repr(error)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for a chain run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
|
||||
chain_run = ChainRun(
|
||||
serialized=serialized,
|
||||
inputs=inputs,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
child_runs=[],
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
)
|
||||
self._start_trace(chain_run)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""End a trace for a chain run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ChainRun):
|
||||
raise TracerException("No ChainRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].outputs = outputs
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Handle an error for a chain run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ChainRun):
|
||||
raise TracerException("No ChainRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].error = repr(error)
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for a tool run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
|
||||
tool_run = ToolRun(
|
||||
serialized=serialized,
|
||||
action=action.tool,
|
||||
tool_input=action.tool_input,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
child_runs=[],
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
)
|
||||
self._start_trace(tool_run)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""End a trace for a tool run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ToolRun):
|
||||
raise TracerException("No ToolRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].output = output
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Handle an error for a tool run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ToolRun):
|
||||
raise TracerException("No ToolRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].error = repr(error)
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Handle a text message."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Handle an agent finish message."""
|
||||
pass
|
||||
|
||||
|
||||
class Tracer(BaseTracer, ABC):
|
||||
"""A non-thread safe implementation of the BaseTracer interface."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a tracer."""
|
||||
self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
self._tracer_execution_order = 1
|
||||
self._tracer_session: Optional[TracerSession] = None
|
||||
|
||||
@property
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
return self._tracer_stack
|
||||
|
||||
@property
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
return self._tracer_execution_order
|
||||
|
||||
@_execution_order.setter
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
self._tracer_execution_order = value
|
||||
|
||||
@property
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
return self._tracer_session
|
||||
|
||||
@_session.setter
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
if self._stack:
|
||||
raise TracerException(
|
||||
"Cannot set a session while a trace is being recorded"
|
||||
)
|
||||
self._tracer_session = value
|
||||
|
||||
|
||||
@dataclass
|
||||
class TracerStack(threading.local):
|
||||
"""A stack of runs used for logging."""
|
||||
|
||||
stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
|
||||
execution_order: int = 1
|
||||
|
||||
|
||||
class SharedTracer(Singleton, BaseTracer, ABC):
|
||||
"""A thread-safe Singleton implementation of BaseTracer."""
|
||||
|
||||
_tracer_stack = TracerStack()
|
||||
_tracer_session = None
|
||||
|
||||
@property
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
return self._tracer_stack.stack
|
||||
|
||||
@property
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
return self._tracer_stack.execution_order
|
||||
|
||||
@_execution_order.setter
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
self._tracer_stack.execution_order = value
|
||||
|
||||
@property
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
return self._tracer_session
|
||||
|
||||
@_session.setter
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
with self._lock:
|
||||
# TODO: currently, we are only checking current thread's stack.
|
||||
# Need to make sure that we are not in the middle of a trace
|
||||
# in any thread.
|
||||
if self._stack:
|
||||
raise TracerException(
|
||||
"Cannot set a session while a trace is being recorded"
|
||||
)
|
||||
self._tracer_session = value
|
@ -0,0 +1,112 @@
|
||||
"""A Tracer implementation that records to LangChain endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
)
|
||||
|
||||
|
||||
class BaseLangChainTracer(BaseTracer, ABC):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
always_verbose: bool = True
|
||||
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
if isinstance(run, LLMRun):
|
||||
endpoint = f"{self._endpoint}/llm-runs"
|
||||
elif isinstance(run, ChainRun):
|
||||
endpoint = f"{self._endpoint}/chain-runs"
|
||||
else:
|
||||
endpoint = f"{self._endpoint}/tool-runs"
|
||||
|
||||
try:
|
||||
requests.post(
|
||||
endpoint,
|
||||
data=run.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
||||
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a session."""
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{self._endpoint}/sessions",
|
||||
data=session_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
session = TracerSession(id=r.json()["id"], **session_create.dict())
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to create session, using default session: {e}")
|
||||
session = TracerSession(id=1, **session_create.dict())
|
||||
return session
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a session from the tracer."""
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{self._endpoint}/sessions?name={session_name}",
|
||||
headers=self._headers,
|
||||
)
|
||||
tracer_session = TracerSession(**r.json()[0])
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Failed to load session {session_name}, using empty session: {e}"
|
||||
)
|
||||
tracer_session = TracerSession(id=1)
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{self._endpoint}/sessions",
|
||||
headers=self._headers,
|
||||
)
|
||||
# Use the first session result
|
||||
tracer_session = TracerSession(**r.json()[0])
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to default session, using empty session: {e}")
|
||||
tracer_session = TracerSession(id=1)
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
if isinstance(child_run, LLMRun):
|
||||
parent_run.child_llm_runs.append(child_run)
|
||||
elif isinstance(child_run, ChainRun):
|
||||
parent_run.child_chain_runs.append(child_run)
|
||||
else:
|
||||
parent_run.child_tool_runs.append(child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return None
|
@ -0,0 +1,76 @@
|
||||
"""Schemas for tracers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class TracerSessionBase(BaseModel):
|
||||
"""Base class for TracerSession."""
|
||||
|
||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
name: Optional[str] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TracerSessionCreate(TracerSessionBase):
|
||||
"""Create class for TracerSession."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TracerSession(TracerSessionBase):
|
||||
"""TracerSession schema."""
|
||||
|
||||
id: int
|
||||
|
||||
|
||||
class BaseRun(BaseModel):
|
||||
"""Base class for Run."""
|
||||
|
||||
id: Optional[Union[int, str]] = None
|
||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
execution_order: int
|
||||
serialized: Dict[str, Any]
|
||||
session_id: int
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class LLMRun(BaseRun):
|
||||
"""Class for LLMRun."""
|
||||
|
||||
prompts: List[str]
|
||||
response: Optional[LLMResult] = None
|
||||
|
||||
|
||||
class ChainRun(BaseRun):
|
||||
"""Class for ChainRun."""
|
||||
|
||||
inputs: Dict[str, Any]
|
||||
outputs: Optional[Dict[str, Any]] = None
|
||||
child_llm_runs: List[LLMRun] = Field(default_factory=list)
|
||||
child_chain_runs: List[ChainRun] = Field(default_factory=list)
|
||||
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolRun(BaseRun):
|
||||
"""Class for ToolRun."""
|
||||
|
||||
tool_input: str
|
||||
output: Optional[str] = None
|
||||
action: str
|
||||
child_llm_runs: List[LLMRun] = Field(default_factory=list)
|
||||
child_chain_runs: List[ChainRun] = Field(default_factory=list)
|
||||
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
|
||||
|
||||
|
||||
ChainRun.update_forward_refs()
|
||||
ToolRun.update_forward_refs()
|
@ -0,0 +1,29 @@
|
||||
version: '3'
|
||||
services:
|
||||
langchain-frontend:
|
||||
image: notlangchain/langchainplus-frontend:latest
|
||||
ports:
|
||||
- 4173:4173
|
||||
environment:
|
||||
- BACKEND_URL=http://langchain-backend:8000
|
||||
- PUBLIC_BASE_URL=http://localhost:8000
|
||||
- PUBLIC_DEV_MODE=true
|
||||
depends_on:
|
||||
- langchain-backend
|
||||
langchain-backend:
|
||||
image: notlangchain/langchainplus:latest
|
||||
environment:
|
||||
- PORT=8000
|
||||
- LANGCHAIN_ENV=local
|
||||
ports:
|
||||
- 8000:8000
|
||||
depends_on:
|
||||
- langchain-db
|
||||
langchain-db:
|
||||
image: postgres:14.1
|
||||
environment:
|
||||
- POSTGRES_PASSWORD=postgres
|
||||
- POSTGRES_USER=postgres
|
||||
- POSTGRES_DB=postgres
|
||||
ports:
|
||||
- 5432:5432
|
@ -0,0 +1,14 @@
|
||||
"""Script to run langchain-server locally using docker-compose."""
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the langchain server locally."""
|
||||
p = Path(__file__).absolute().parent / "docker-compose.yaml"
|
||||
subprocess.run(["docker-compose", "-f", str(p), "pull"])
|
||||
subprocess.run(["docker-compose", "-f", str(p), "up"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1 @@
|
||||
"""Tests for correct functioning of tracers."""
|
@ -0,0 +1,530 @@
|
||||
"""Test Tracer classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain.callbacks.tracers.base import (
|
||||
BaseTracer,
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
SharedTracer,
|
||||
ToolRun,
|
||||
Tracer,
|
||||
TracerException,
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
return ChainRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
child_runs=[
|
||||
ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
),
|
||||
LLMRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
"""Perform a nested run."""
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_end("test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_chain_end(outputs={})
|
||||
|
||||
|
||||
def _add_child_run(
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
parent_run.child_runs.append(child_run)
|
||||
|
||||
|
||||
def _generate_id() -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return None
|
||||
|
||||
|
||||
def load_session(session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
||||
|
||||
|
||||
def _persist_session(session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return TracerSession(id=TEST_SESSION_ID, **session.dict())
|
||||
|
||||
|
||||
def load_default_session() -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name="default", start_time=datetime.utcnow())
|
||||
|
||||
|
||||
class FakeTracer(Tracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
_add_child_run(parent_run, child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return _generate_id()
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
class FakeSharedTracer(SharedTracer):
|
||||
"""Fake shared tracer that records LangChain execution."""
|
||||
|
||||
runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
with self._lock:
|
||||
self.runs.append(run)
|
||||
|
||||
def remove_runs(self) -> None:
|
||||
"""Remove all runs."""
|
||||
with self._lock:
|
||||
self.runs = []
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
_add_child_run(parent_run, child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return _generate_id()
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_session() -> None:
|
||||
"""Test tracer on an LLM run without a session."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_multiple_llm_runs() -> None:
|
||||
"""Test the tracer with multiple runs."""
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run() -> None:
|
||||
"""Test tracer on a Chain run."""
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_chain_end(outputs={})
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run() -> None:
|
||||
"""Test tracer on a Tool run."""
|
||||
compare_run = ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_end("test")
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_run() -> None:
|
||||
"""Test tracer on a nested run."""
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
_perform_nested_run(tracer)
|
||||
assert tracer.runs == [_get_compare_run()]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_on_error() -> None:
|
||||
"""Test tracer on an LLM run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run_on_error() -> None:
|
||||
"""Test tracer on a Chain run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_chain_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run_on_error() -> None:
|
||||
"""Test tracer on a Tool run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
compare_run = ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_runs_on_error() -> None:
|
||||
"""Test tracer on a nested run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_tool_error(exception)
|
||||
tracer.on_chain_error(exception)
|
||||
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
prompts=[],
|
||||
response=None,
|
||||
)
|
||||
],
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
),
|
||||
],
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
)
|
||||
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_shared_tracer_nested_run() -> None:
|
||||
"""Test shared tracer on a nested run."""
|
||||
tracer = FakeSharedTracer()
|
||||
tracer.new_session()
|
||||
tracer.remove_runs()
|
||||
_perform_nested_run(tracer)
|
||||
assert tracer.runs == [_get_compare_run()]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_shared_tracer_nested_run_multithreaded() -> None:
|
||||
"""Test shared tracer on a nested run."""
|
||||
tracer = FakeSharedTracer()
|
||||
tracer.remove_runs()
|
||||
tracer.new_session()
|
||||
threads = []
|
||||
num_threads = 10
|
||||
for _ in range(num_threads):
|
||||
thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert tracer.runs == [_get_compare_run()] * num_threads
|
Loading…
Reference in New Issue