v2 tracer with single runs endpoint (#3951)

fix_agent_callbacks
Ankush Gola 1 year ago committed by GitHub
parent 8fcb56e74a
commit 3bd5a99b83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,7 @@ import asyncio
import copy import copy
import functools import functools
import os import os
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
@ -20,7 +21,7 @@ from langchain.callbacks.base import (
from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@ -46,7 +47,7 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
def tracing_enabled( def tracing_enabled(
session_name: str = "default", session_name: str = "default",
) -> Generator[TracerSession, None, None]: ) -> Generator[TracerSession, None, None]:
"""Get OpenAI callback handler in a context manager.""" """Get Tracer in a context manager."""
cb = LangChainTracer() cb = LangChainTracer()
session = cb.load_session(session_name) session = cb.load_session(session_name)
tracing_callback_var.set(cb) tracing_callback_var.set(cb)
@ -54,6 +55,23 @@ def tracing_enabled(
tracing_callback_var.set(None) tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
warnings.warn(
"The experimental tracing v2 is in development. "
"This is not yet stable and may change in the future."
)
cb = LangChainTracerV2()
session = cb.load_session(session_name)
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
def _handle_event( def _handle_event(
handlers: List[BaseCallbackHandler], handlers: List[BaseCallbackHandler],
event_name: str, event_name: str,

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any, Dict, Optional, Union from typing import Any, Dict, List, Optional, Union
import requests import requests
@ -11,6 +11,7 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import ( from langchain.callbacks.tracers.schemas import (
ChainRun, ChainRun,
LLMRun, LLMRun,
Run,
ToolRun, ToolRun,
TracerSession, TracerSession,
TracerSessionCreate, TracerSessionCreate,
@ -87,3 +88,68 @@ class LangChainTracer(BaseTracer):
def load_default_session(self) -> TracerSession: def load_default_session(self) -> TracerSession:
"""Load the default tracing session and set it as the Tracer's session.""" """Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default") return self._load_session("default")
class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@staticmethod
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
inputs: Dict[str, Any] = {}
outputs: Optional[Dict[str, Any]] = None
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
if isinstance(run, LLMRun):
run_type = "llm"
inputs = {"prompts": run.prompts}
outputs = run.response.dict() if run.response else {}
child_runs = []
elif isinstance(run, ChainRun):
run_type = "chain"
inputs = run.inputs
outputs = run.outputs
child_runs = [
*run.child_llm_runs,
*run.child_chain_runs,
*run.child_tool_runs,
]
else:
run_type = "tool"
inputs = {"input": run.tool_input}
outputs = {"output": run.output} if run.output else {}
child_runs = [
*run.child_llm_runs,
*run.child_chain_runs,
*run.child_tool_runs,
]
return Run(
id=run.uuid,
name=run.serialized.get("name"),
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
error=run.error,
execution_order=run.execution_order,
serialized=run.serialized,
inputs=inputs,
outputs=outputs,
session_id=run.session_id,
run_type=run_type,
parent_run_id=run.parent_uuid,
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
run_create = self._convert_run(run)
try:
requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")

@ -2,7 +2,9 @@
from __future__ import annotations from __future__ import annotations
import datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -72,5 +74,31 @@ class ToolRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list) child_tool_runs: List[ToolRun] = Field(default_factory=list)
class RunTypeEnum(str, Enum):
"""Enum for run types."""
tool = "tool"
chain = "chain"
llm = "llm"
class Run(BaseModel):
id: Optional[UUID]
name: str
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict
error: Optional[str]
execution_order: int
serialized: dict
inputs: dict
outputs: Optional[dict]
session_id: int
parent_run_id: Optional[UUID]
example_id: Optional[UUID]
run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list)
ChainRun.update_forward_refs() ChainRun.update_forward_refs()
ToolRun.update_forward_refs() ToolRun.update_forward_refs()

Loading…
Cancel
Save