Update V2 Tracer (#4193)

- Update the RunCreate object to work with recent changes
- Add optional Example ID to the tracer
- Adjust the default _persist_session behavior: if creating the session
  fails, attempt to load the existing session with the same name
- Raise more useful HTTP errors for logging
- Add unit testing
- Fix the default ID to be a UUID for v2 tracer sessions


Broken out from the big draft here:
https://github.com/hwchase17/langchain/pull/4061
parallel_dir_loader
Zander Chase 1 year ago committed by GitHub
parent c3044b1bf0
commit a30f42da4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -58,6 +58,7 @@ def tracing_enabled(
@contextmanager @contextmanager
def tracing_v2_enabled( def tracing_v2_enabled(
session_name: str = "default", session_name: str = "default",
example_id: Optional[Union[str, UUID]] = None,
) -> Generator[TracerSessionV2, None, None]: ) -> Generator[TracerSessionV2, None, None]:
"""Get the experimental tracer handler in a context manager.""" """Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental # Issue a warning that this is experimental
@ -65,8 +66,10 @@ def tracing_v2_enabled(
"The experimental tracing v2 is in development. " "The experimental tracing v2 is in development. "
"This is not yet stable and may change in the future." "This is not yet stable and may change in the future."
) )
cb = LangChainTracerV2() if isinstance(example_id, str):
session = cb.load_session(session_name) example_id = UUID(example_id)
cb = LangChainTracerV2(example_id=example_id)
session = cast(TracerSessionV2, cb.new_session(session_name))
tracing_callback_var.set(cb) tracing_callback_var.set(cb)
yield session yield session
tracing_callback_var.set(None) tracing_callback_var.set(None)

@ -29,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None self.session: Optional[Union[TracerSession, TracerSessionV2]] = None
@staticmethod @staticmethod
def _add_child_run( def _add_child_run(
@ -165,7 +165,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
llm_run = self.run_map.get(run_id_) llm_run = self.run_map.get(run_id_)
if llm_run is None or not isinstance(llm_run, LLMRun): if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced") raise TracerException("No LLMRun found to be traced")
llm_run.response = response llm_run.response = response
llm_run.end_time = datetime.utcnow() llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run) self._end_trace(llm_run)

@ -4,6 +4,7 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4
import requests import requests
@ -11,13 +12,14 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import ( from langchain.callbacks.tracers.schemas import (
ChainRun, ChainRun,
LLMRun, LLMRun,
Run, RunCreate,
ToolRun, ToolRun,
TracerSession, TracerSession,
TracerSessionBase, TracerSessionBase,
TracerSessionV2, TracerSessionV2,
TracerSessionV2Create, TracerSessionV2Create,
) )
from langchain.utils import raise_for_status_with_text
def _get_headers() -> Dict[str, Any]: def _get_headers() -> Dict[str, Any]:
@ -51,11 +53,12 @@ class LangChainTracer(BaseTracer):
endpoint = f"{self._endpoint}/tool-runs" endpoint = f"{self._endpoint}/tool-runs"
try: try:
requests.post( response = requests.post(
endpoint, endpoint,
data=run.json(), data=run.json(),
headers=self._headers, headers=self._headers,
) )
raise_for_status_with_text(response)
except Exception as e: except Exception as e:
logging.warning(f"Failed to persist run: {e}") logging.warning(f"Failed to persist run: {e}")
@ -111,7 +114,7 @@ def _get_tenant_id() -> Optional[str]:
endpoint = _get_endpoint() endpoint = _get_endpoint()
headers = _get_headers() headers = _get_headers()
response = requests.get(endpoint + "/tenants", headers=headers) response = requests.get(endpoint + "/tenants", headers=headers)
response.raise_for_status() raise_for_status_with_text(response)
tenants: List[Dict[str, Any]] = response.json() tenants: List[Dict[str, Any]] = response.json()
if not tenants: if not tenants:
raise ValueError(f"No tenants found for URL {endpoint}") raise ValueError(f"No tenants found for URL {endpoint}")
@ -121,12 +124,13 @@ def _get_tenant_id() -> Optional[str]:
class LangChainTracerV2(LangChainTracer): class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint.""" """An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None: def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""Initialize the LangChain tracer.""" """Initialize the LangChain tracer."""
super().__init__(**kwargs) super().__init__(**kwargs)
self._endpoint = _get_endpoint() self._endpoint = _get_endpoint()
self._headers = _get_headers() self._headers = _get_headers()
self.tenant_id = _get_tenant_id() self.tenant_id = _get_tenant_id()
self.example_id = example_id
def _get_session_create( def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any self, name: Optional[str] = None, **kwargs: Any
@ -135,16 +139,30 @@ class LangChainTracerV2(LangChainTracer):
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2: def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
"""Persist a session.""" """Persist a session."""
session: Optional[TracerSessionV2] = None
try: try:
r = requests.post( r = requests.post(
f"{self._endpoint}/sessions", f"{self._endpoint}/sessions",
data=session_create.json(), data=session_create.json(),
headers=self._headers, headers=self._headers,
) )
session = TracerSessionV2(id=r.json()["id"], **session_create.dict()) raise_for_status_with_text(r)
creation_args = session_create.dict()
if "id" in creation_args:
del creation_args["id"]
return TracerSessionV2(id=r.json()["id"], **creation_args)
except Exception as e: except Exception as e:
logging.warning(f"Failed to create session, using default session: {e}") if session_create.name is not None:
session = self.load_session("default") try:
return self.load_session(session_create.name)
except Exception:
pass
logging.warning(
f"Failed to create session {session_create.name},"
f" using empty session: {e}"
)
session = TracerSessionV2(id=uuid4(), **session_create.dict())
return session return session
def _get_default_query_params(self) -> Dict[str, Any]: def _get_default_query_params(self) -> Dict[str, Any]:
@ -159,13 +177,14 @@ class LangChainTracerV2(LangChainTracer):
if session_name: if session_name:
params["name"] = session_name params["name"] = session_name
r = requests.get(url, headers=self._headers, params=params) r = requests.get(url, headers=self._headers, params=params)
raise_for_status_with_text(r)
tracer_session = TracerSessionV2(**r.json()[0]) tracer_session = TracerSessionV2(**r.json()[0])
except Exception as e: except Exception as e:
session_type = "default" if not session_name else session_name session_type = "default" if not session_name else session_name
logging.warning( logging.warning(
f"Failed to load {session_type} session, using empty session: {e}" f"Failed to load {session_type} session, using empty session: {e}"
) )
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id) tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)
self.session = tracer_session self.session = tracer_session
return tracer_session return tracer_session
@ -174,7 +193,7 @@ class LangChainTracerV2(LangChainTracer):
"""Load the default tracing session and set it as the Tracer's session.""" """Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default") return self.load_session("default")
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run: def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
"""Convert a run to a Run.""" """Convert a run to a Run."""
session = self.session or self.load_default_session() session = self.session or self.load_default_session()
inputs: Dict[str, Any] = {} inputs: Dict[str, Any] = {}
@ -204,9 +223,9 @@ class LangChainTracerV2(LangChainTracer):
*run.child_tool_runs, *run.child_tool_runs,
] ]
return Run( return RunCreate(
id=run.uuid, id=run.uuid,
name=run.serialized.get("name", f"{run_type}-{run.uuid}"), name=run.serialized.get("name"),
start_time=run.start_time, start_time=run.start_time,
end_time=run.end_time, end_time=run.end_time,
extra=run.extra or {}, extra=run.extra or {},
@ -217,7 +236,7 @@ class LangChainTracerV2(LangChainTracer):
outputs=outputs, outputs=outputs,
session_id=session.id, session_id=session.id,
run_type=run_type, run_type=run_type,
parent_run_id=run.parent_uuid, reference_example_id=self.example_id,
child_runs=[self._convert_run(child) for child in child_runs], child_runs=[self._convert_run(child) for child in child_runs],
) )
@ -225,11 +244,11 @@ class LangChainTracerV2(LangChainTracer):
"""Persist a run.""" """Persist a run."""
run_create = self._convert_run(run) run_create = self._convert_run(run)
try: try:
result = requests.post( response = requests.post(
f"{self._endpoint}/runs", f"{self._endpoint}/runs",
data=run_create.json(), data=run_create.json(),
headers=self._headers, headers=self._headers,
) )
result.raise_for_status() raise_for_status_with_text(response)
except Exception as e: except Exception as e:
logging.warning(f"Failed to persist run: {e}") logging.warning(f"Failed to persist run: {e}")

@ -37,9 +37,11 @@ class TracerSessionV2Base(TracerSessionBase):
tenant_id: UUID tenant_id: UUID
class TracerSessionV2Create(TracerSessionBase): class TracerSessionV2Create(TracerSessionV2Base):
"""A creation class for TracerSessionV2.""" """A creation class for TracerSessionV2."""
id: Optional[UUID]
pass pass
@ -100,9 +102,10 @@ class RunTypeEnum(str, Enum):
llm = "llm" llm = "llm"
class Run(BaseModel): class RunBase(BaseModel):
"""Base Run schema."""
id: Optional[UUID] id: Optional[UUID]
name: str
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict extra: dict
@ -112,10 +115,22 @@ class Run(BaseModel):
inputs: dict inputs: dict
outputs: Optional[dict] outputs: Optional[dict]
session_id: UUID session_id: UUID
parent_run_id: Optional[UUID]
reference_example_id: Optional[UUID] reference_example_id: Optional[UUID]
run_type: RunTypeEnum run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list)
class RunCreate(RunBase):
    """Schema to create a run in the DB."""

    # Declared Optional[str], so None is an accepted value for the run name.
    name: Optional[str]
    # Nested child runs, each itself a RunCreate; defaults to a fresh empty
    # list per instance via default_factory.
    # NOTE(review): the self-referential annotation presumably relies on
    # postponed annotations / forward-ref resolution — confirm.
    child_runs: List[RunCreate] = Field(default_factory=list)
class Run(RunBase):
    """Run schema when loading from the DB."""

    # Declared as a plain str here, whereas RunCreate declares Optional[str].
    name: str
    # ID of the parent run; None is accepted
    # (presumably meaning a top-level run — confirm).
    parent_run_id: Optional[UUID]
ChainRun.update_forward_refs() ChainRun.update_forward_refs()

@ -2,6 +2,8 @@
import os import os
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
from requests import HTTPError, Response
def get_from_dict_or_env( def get_from_dict_or_env(
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
@ -52,6 +54,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
return decorator return decorator
def raise_for_status_with_text(response: Response) -> None:
    """Raise a ``ValueError`` carrying the response body for a failed request.

    The success/failure decision is delegated to
    ``response.raise_for_status()``; nothing happens when that call does not
    raise.

    Args:
        response: The HTTP response to check.

    Raises:
        ValueError: If ``raise_for_status()`` raises ``HTTPError``. The message
            is ``response.text``; when the body is empty, the text of the
            original ``HTTPError`` is used instead so the error is never
            blank. The ``HTTPError`` is chained as ``__cause__``.
    """
    try:
        response.raise_for_status()
    except HTTPError as e:
        # An empty body would otherwise produce ValueError(""), leaving the
        # callers' "Failed to ...: {e}" log lines without any detail.
        raise ValueError(response.text or str(e)) from e
def stringify_value(val: Any) -> str: def stringify_value(val: Any) -> str:
if isinstance(val, str): if isinstance(val, str):
return val return val

@ -18,7 +18,12 @@ from langchain.callbacks.tracers.base import (
TracerSession, TracerSession,
) )
from langchain.callbacks.tracers.langchain import LangChainTracerV2 from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2 from langchain.callbacks.tracers.schemas import (
RunCreate,
TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
)
from langchain.schema import LLMResult from langchain.schema import LLMResult
TEST_SESSION_ID = 2023 TEST_SESSION_ID = 2023
@ -541,14 +546,12 @@ def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
return llm_run, chain_run, tool_run return llm_run, chain_run, tool_run
# Test _get_default_query_params method
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None: def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
expected = {"tenant_id": "test-tenant-id"} expected = {"tenant_id": "test-tenant-id"}
result = lang_chain_tracer_v2._get_default_query_params() result = lang_chain_tracer_v2._get_default_query_params()
assert result == expected assert result == expected
# Test load_session method
@patch("langchain.callbacks.tracers.langchain.requests.get") @patch("langchain.callbacks.tracers.langchain.requests.get")
def test_load_session( def test_load_session(
mock_requests_get: Mock, mock_requests_get: Mock,
@ -577,23 +580,65 @@ def test_convert_run(
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run) converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run) converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
assert isinstance(converted_llm_run, Run) assert isinstance(converted_llm_run, RunCreate)
assert isinstance(converted_chain_run, Run) assert isinstance(converted_chain_run, RunCreate)
assert isinstance(converted_tool_run, Run) assert isinstance(converted_tool_run, RunCreate)
@patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_run( def test_persist_run(
mock_requests_post: Mock,
lang_chain_tracer_v2: LangChainTracerV2, lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2, sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun], sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None: ) -> None:
mock_requests_post.return_value.raise_for_status.return_value = None """Test that persist_run method calls requests.post once per method call."""
lang_chain_tracer_v2.session = sample_tracer_session_v2 with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
llm_run, chain_run, tool_run = sample_runs "langchain.callbacks.tracers.langchain.requests.get"
lang_chain_tracer_v2._persist_run(llm_run) ) as get:
lang_chain_tracer_v2._persist_run(chain_run) post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2._persist_run(tool_run) lang_chain_tracer_v2.session = sample_tracer_session_v2
llm_run, chain_run, tool_run = sample_runs
lang_chain_tracer_v2._persist_run(llm_run)
lang_chain_tracer_v2._persist_run(chain_run)
lang_chain_tracer_v2._persist_run(tool_run)
assert post.call_count == 3
assert get.call_count == 0
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
    """Check that `_get_session_create` builds a TracerSessionV2Create."""
    tracer = lang_chain_tracer_v2
    tracer.tenant_id = str(_TENANT_ID)

    created = tracer._get_session_create(name="test")

    assert isinstance(created, TracerSessionV2Create)
    assert created.name == "test"
    assert created.tenant_id == _TENANT_ID
assert mock_requests_post.call_count == 3 @patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_session(
    mock_requests_post: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
) -> None:
    """Check `_persist_session` yields a TracerSessionV2 with the updated ID."""
    # The mocked POST response reports this ID for the newly created session.
    server_id = str(uuid4())
    mock_requests_post.return_value.json.return_value = {"id": server_id}
    payload = TracerSessionV2Create(**sample_tracer_session_v2.dict())

    persisted = lang_chain_tracer_v2._persist_session(payload)

    assert isinstance(persisted, TracerSessionV2)
    # Everything matches the sample session except the ID, which must be the
    # one returned by the (mocked) endpoint.
    expected = {**sample_tracer_session_v2.dict(), "id": UUID(server_id)}
    assert persisted.dict() == expected
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
def test_load_default_session(
    mock_load_session: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
) -> None:
    """Check `load_default_session` delegates to `load_session("default")`."""
    expected_session = sample_tracer_session_v2
    mock_load_session.return_value = expected_session

    loaded = lang_chain_tracer_v2.load_default_session()

    assert loaded == expected_session
    mock_load_session.assert_called_with("default")

Loading…
Cancel
Save