From 3bd5a99b835fa320d02aa733cb0c0bc4a87724fa Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Mon, 1 May 2023 22:41:32 -0700 Subject: [PATCH] v2 tracer with single runs endpoint (#3951) --- langchain/callbacks/manager.py | 22 +++++++- langchain/callbacks/tracers/langchain.py | 68 +++++++++++++++++++++++- langchain/callbacks/tracers/schemas.py | 28 ++++++++++ 3 files changed, 115 insertions(+), 3 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index a21847a5..2a2e7b6e 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -4,6 +4,7 @@ import asyncio import copy import functools import os +import warnings from contextlib import contextmanager from contextvars import ContextVar from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union @@ -20,7 +21,7 @@ from langchain.callbacks.base import ( from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.tracers.base import TracerSession -from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2 from langchain.schema import AgentAction, AgentFinish, LLMResult Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] @@ -46,7 +47,7 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: def tracing_enabled( session_name: str = "default", ) -> Generator[TracerSession, None, None]: - """Get OpenAI callback handler in a context manager.""" + """Get Tracer in a context manager.""" cb = LangChainTracer() session = cb.load_session(session_name) tracing_callback_var.set(cb) @@ -54,6 +55,23 @@ def tracing_enabled( tracing_callback_var.set(None) +@contextmanager +def tracing_v2_enabled( + session_name: str = "default", +) -> Generator[TracerSession, None, None]: + """Get the experimental tracer handler in a context manager.""" + # Issue a warning that this is experimental + warnings.warn( + "The experimental tracing v2 is in development. " + "This is not yet stable and may change in the future." + ) + cb = LangChainTracerV2() + session = cb.load_session(session_name) + tracing_callback_var.set(cb) + yield session + tracing_callback_var.set(None) + + def _handle_event( handlers: List[BaseCallbackHandler], event_name: str, diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 80e7d2d2..3d9a6b0d 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import requests @@ -11,6 +11,7 @@ from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import ( ChainRun, LLMRun, + Run, ToolRun, TracerSession, TracerSessionCreate, @@ -87,3 +88,68 @@ class LangChainTracer(BaseTracer): def load_default_session(self) -> TracerSession: """Load the default tracing session and set it as the Tracer's session.""" return self._load_session("default") + + +class LangChainTracerV2(LangChainTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + @staticmethod + def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run: + """Convert a run to a Run.""" + + inputs: Dict[str, Any] = {} + outputs: Optional[Dict[str, Any]] = None + child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] + if isinstance(run, LLMRun): + run_type = "llm" + inputs = {"prompts": run.prompts} + outputs = run.response.dict() if run.response else {} + child_runs = [] + elif isinstance(run, ChainRun): + run_type = "chain" + inputs = run.inputs + outputs = run.outputs + child_runs = [ + *run.child_llm_runs, + *run.child_chain_runs, + *run.child_tool_runs, + ] + else: + run_type = "tool" + inputs = {"input": run.tool_input} + outputs = {"output": run.output} if run.output else {} + child_runs = [ + *run.child_llm_runs, + *run.child_chain_runs, + *run.child_tool_runs, + ] + + return Run( + id=run.uuid, + name=run.serialized.get("name"), + start_time=run.start_time, + end_time=run.end_time, + extra=run.extra, + error=run.error, + execution_order=run.execution_order, + serialized=run.serialized, + inputs=inputs, + outputs=outputs, + session_id=run.session_id, + run_type=run_type, + parent_run_id=run.parent_uuid, + child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs], + ) + + def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + run_create = self._convert_run(run) + + try: + requests.post( + f"{self._endpoint}/runs", + data=run_create.json(), + headers=self._headers, + ) + except Exception as e: + logging.warning(f"Failed to persist run: {e}") diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index ce6368ff..e863bd04 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -2,7 +2,9 @@ from __future__ import annotations import datetime +from enum import Enum from typing import Any, Dict, List, Optional +from uuid import UUID from pydantic import BaseModel, Field @@ -72,5 +74,31 @@ class ToolRun(BaseRun): child_tool_runs: List[ToolRun] = Field(default_factory=list) +class RunTypeEnum(str, Enum): + """Enum for run types.""" + + tool = "tool" + chain = "chain" + llm = "llm" + + +class Run(BaseModel): + id: Optional[UUID] + name: str + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + extra: dict + error: Optional[str] + execution_order: int + serialized: dict + inputs: dict + outputs: Optional[dict] + session_id: int + parent_run_id: Optional[UUID] + example_id: Optional[UUID] + run_type: RunTypeEnum + child_runs: List[Run] = Field(default_factory=list) + + ChainRun.update_forward_refs() ToolRun.update_forward_refs()