v2 tracer with single runs endpoint (#3951)

This commit is contained in:
Ankush Gola 2023-05-01 22:41:32 -07:00 committed by GitHub
parent 8fcb56e74a
commit 3bd5a99b83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 115 additions and 3 deletions

View File

@ -4,6 +4,7 @@ import asyncio
import copy
import functools
import os
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
@ -20,7 +21,7 @@ from langchain.callbacks.base import (
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@ -46,7 +47,7 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
"""Get OpenAI callback handler in a context manager."""
"""Get Tracer in a context manager."""
cb = LangChainTracer()
session = cb.load_session(session_name)
tracing_callback_var.set(cb)
@ -54,6 +55,23 @@ def tracing_enabled(
tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
warnings.warn(
"The experimental tracing v2 is in development. "
"This is not yet stable and may change in the future."
)
cb = LangChainTracerV2()
session = cb.load_session(session_name)
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
def _handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import requests
@ -11,6 +11,7 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
Run,
ToolRun,
TracerSession,
TracerSessionCreate,
@ -87,3 +88,68 @@ class LangChainTracer(BaseTracer):
def load_default_session(self) -> TracerSession:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")
class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@staticmethod
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
inputs: Dict[str, Any] = {}
outputs: Optional[Dict[str, Any]] = None
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
if isinstance(run, LLMRun):
run_type = "llm"
inputs = {"prompts": run.prompts}
outputs = run.response.dict() if run.response else {}
child_runs = []
elif isinstance(run, ChainRun):
run_type = "chain"
inputs = run.inputs
outputs = run.outputs
child_runs = [
*run.child_llm_runs,
*run.child_chain_runs,
*run.child_tool_runs,
]
else:
run_type = "tool"
inputs = {"input": run.tool_input}
outputs = {"output": run.output} if run.output else {}
child_runs = [
*run.child_llm_runs,
*run.child_chain_runs,
*run.child_tool_runs,
]
return Run(
id=run.uuid,
name=run.serialized.get("name"),
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
error=run.error,
execution_order=run.execution_order,
serialized=run.serialized,
inputs=inputs,
outputs=outputs,
session_id=run.session_id,
run_type=run_type,
parent_run_id=run.parent_uuid,
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
run_create = self._convert_run(run)
try:
requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")

View File

@ -2,7 +2,9 @@
from __future__ import annotations
import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
@ -72,5 +74,31 @@ class ToolRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list)
class RunTypeEnum(str, Enum):
"""Enum for run types."""
tool = "tool"
chain = "chain"
llm = "llm"
class Run(BaseModel):
id: Optional[UUID]
name: str
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict
error: Optional[str]
execution_order: int
serialized: dict
inputs: dict
outputs: Optional[dict]
session_id: int
parent_run_id: Optional[UUID]
example_id: Optional[UUID]
run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list)
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()