From a30f42da4e5702c4c2f97cd63176e863bcb6b1a9 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Fri, 5 May 2023 14:55:01 -0700 Subject: [PATCH] Update V2 Tracer (#4193) - Update the RunCreate object to work with recent changes - Add optional Example ID to the tracer - Adjust default persist_session behavior to attempt to load the session if it exists - Raise more useful HTTP errors for logging - Add unit testing - Fix the default ID to be a UUID for v2 tracer sessions Broken out from the big draft here: https://github.com/hwchase17/langchain/pull/4061 --- langchain/callbacks/manager.py | 7 +- langchain/callbacks/tracers/base.py | 3 +- langchain/callbacks/tracers/langchain.py | 47 ++++++++---- langchain/callbacks/tracers/schemas.py | 25 +++++-- langchain/utils.py | 10 +++ .../callbacks/tracers/test_tracer.py | 75 +++++++++++++++---- 6 files changed, 129 insertions(+), 38 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 5fa2b3dd..cd03de43 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -58,6 +58,7 @@ def tracing_enabled( @contextmanager def tracing_v2_enabled( session_name: str = "default", + example_id: Optional[Union[str, UUID]] = None, ) -> Generator[TracerSessionV2, None, None]: """Get the experimental tracer handler in a context manager.""" # Issue a warning that this is experimental @@ -65,8 +66,10 @@ def tracing_v2_enabled( "The experimental tracing v2 is in development. " "This is not yet stable and may change in the future." ) - cb = LangChainTracerV2() - session = cb.load_session(session_name) + if isinstance(example_id, str): + example_id = UUID(example_id) + cb = LangChainTracerV2(example_id=example_id) + session = cast(TracerSessionV2, cb.new_session(session_name)) tracing_callback_var.set(cb) yield session tracing_callback_var.set(None) diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 6d036d32..4516f8b5 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -29,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} - self.session: Optional[Union[TracerSessionV2, TracerSession]] = None + self.session: Optional[Union[TracerSession, TracerSessionV2]] = None @staticmethod def _add_child_run( @@ -165,7 +165,6 @@ class BaseTracer(BaseCallbackHandler, ABC): llm_run = self.run_map.get(run_id_) if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") - llm_run.response = response llm_run.end_time = datetime.utcnow() self._end_trace(llm_run) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 767d1538..41c6446f 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging import os from typing import Any, Dict, List, Optional, Union +from uuid import UUID, uuid4 import requests @@ -11,13 +12,14 @@ from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import ( ChainRun, LLMRun, - Run, + RunCreate, ToolRun, TracerSession, TracerSessionBase, TracerSessionV2, TracerSessionV2Create, ) +from langchain.utils import raise_for_status_with_text def _get_headers() -> Dict[str, Any]: @@ -51,11 +53,12 @@ class LangChainTracer(BaseTracer): endpoint = f"{self._endpoint}/tool-runs" try: - requests.post( + response = requests.post( endpoint, data=run.json(), headers=self._headers, ) + raise_for_status_with_text(response) except Exception as e: logging.warning(f"Failed to persist run: {e}") @@ -111,7 +114,7 @@ def _get_tenant_id() -> Optional[str]: endpoint = _get_endpoint() headers = _get_headers() response = requests.get(endpoint + "/tenants", headers=headers) - response.raise_for_status() + raise_for_status_with_text(response) tenants: List[Dict[str, Any]] = response.json() if not tenants: raise ValueError(f"No tenants found for URL {endpoint}") @@ -121,12 +124,13 @@ def _get_tenant_id() -> Optional[str]: class LangChainTracerV2(LangChainTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - def __init__(self, **kwargs: Any) -> None: + def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None: """Initialize the LangChain tracer.""" super().__init__(**kwargs) self._endpoint = _get_endpoint() self._headers = _get_headers() self.tenant_id = _get_tenant_id() + self.example_id = example_id def _get_session_create( self, name: Optional[str] = None, **kwargs: Any @@ -135,16 +139,30 @@ class LangChainTracerV2(LangChainTracer): def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2: """Persist a session.""" + session: Optional[TracerSessionV2] = None try: r = requests.post( f"{self._endpoint}/sessions", data=session_create.json(), headers=self._headers, ) - session = TracerSessionV2(id=r.json()["id"], **session_create.dict()) + raise_for_status_with_text(r) + creation_args = session_create.dict() + if "id" in creation_args: + del creation_args["id"] + return TracerSessionV2(id=r.json()["id"], **creation_args) except Exception as e: - logging.warning(f"Failed to create session, using default session: {e}") - session = self.load_session("default") + if session_create.name is not None: + try: + return self.load_session(session_create.name) + except Exception: + pass + logging.warning( + f"Failed to create session {session_create.name}," + f" using empty session: {e}" + ) + session = TracerSessionV2(id=uuid4(), **session_create.dict()) + return session def _get_default_query_params(self) -> Dict[str, Any]: @@ -159,13 +177,14 @@ class LangChainTracerV2(LangChainTracer): if session_name: params["name"] = session_name r = requests.get(url, headers=self._headers, params=params) + raise_for_status_with_text(r) tracer_session = TracerSessionV2(**r.json()[0]) except Exception as e: session_type = "default" if not session_name else session_name logging.warning( f"Failed to load {session_type} session, using empty session: {e}" ) - tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id) + tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id) self.session = tracer_session return tracer_session @@ -174,7 +193,7 @@ class LangChainTracerV2(LangChainTracer): """Load the default tracing session and set it as the Tracer's session.""" return self.load_session("default") - def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run: + def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate: """Convert a run to a Run.""" session = self.session or self.load_default_session() inputs: Dict[str, Any] = {} @@ -204,9 +223,9 @@ class LangChainTracerV2(LangChainTracer): *run.child_tool_runs, ] - return Run( + return RunCreate( id=run.uuid, - name=run.serialized.get("name", f"{run_type}-{run.uuid}"), + name=run.serialized.get("name"), start_time=run.start_time, end_time=run.end_time, extra=run.extra or {}, @@ -217,7 +236,7 @@ class LangChainTracerV2(LangChainTracer): outputs=outputs, session_id=session.id, run_type=run_type, - parent_run_id=run.parent_uuid, + reference_example_id=self.example_id, child_runs=[self._convert_run(child) for child in child_runs], ) @@ -225,11 +244,11 @@ class LangChainTracerV2(LangChainTracer): """Persist a run.""" run_create = self._convert_run(run) try: - result = requests.post( + response = requests.post( f"{self._endpoint}/runs", data=run_create.json(), headers=self._headers, ) - result.raise_for_status() + raise_for_status_with_text(response) except Exception as e: logging.warning(f"Failed to persist run: {e}") diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index 016ef93d..f38094ae 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -37,9 +37,11 @@ class TracerSessionV2Base(TracerSessionBase): tenant_id: UUID -class TracerSessionV2Create(TracerSessionBase): +class TracerSessionV2Create(TracerSessionV2Base): """A creation class for TracerSessionV2.""" + id: Optional[UUID] + pass @@ -100,9 +102,10 @@ class RunTypeEnum(str, Enum): llm = "llm" -class Run(BaseModel): +class RunBase(BaseModel): + """Base Run schema.""" + id: Optional[UUID] - name: str start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) extra: dict @@ -112,10 +115,22 @@ class Run(BaseModel): inputs: dict outputs: Optional[dict] session_id: UUID - parent_run_id: Optional[UUID] reference_example_id: Optional[UUID] run_type: RunTypeEnum - child_runs: List[Run] = Field(default_factory=list) + + +class RunCreate(RunBase): + """Schema to create a run in the DB.""" + + name: Optional[str] + child_runs: List[RunCreate] = Field(default_factory=list) + + +class Run(RunBase): + """Run schema when loading from the DB.""" + + name: str + parent_run_id: Optional[UUID] ChainRun.update_forward_refs() diff --git a/langchain/utils.py b/langchain/utils.py index 1ce2c373..7420b371 100644 --- a/langchain/utils.py +++ b/langchain/utils.py @@ -2,6 +2,8 @@ import os from typing import Any, Callable, Dict, Optional, Tuple +from requests import HTTPError, Response + def get_from_dict_or_env( data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None @@ -52,6 +54,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: return decorator +def raise_for_status_with_text(response: Response) -> None: + """Raise an error with the response text.""" + try: + response.raise_for_status() + except HTTPError as e: + raise ValueError(response.text) from e + + def stringify_value(val: Any) -> str: if isinstance(val, str): return val diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 49d83f08..0791885b 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -18,7 +18,12 @@ from langchain.callbacks.tracers.base import ( TracerSession, ) from langchain.callbacks.tracers.langchain import LangChainTracerV2 -from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2 +from langchain.callbacks.tracers.schemas import ( + RunCreate, + TracerSessionBase, + TracerSessionV2, + TracerSessionV2Create, +) from langchain.schema import LLMResult TEST_SESSION_ID = 2023 @@ -541,14 +546,12 @@ def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]: return llm_run, chain_run, tool_run -# Test _get_default_query_params method def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None: expected = {"tenant_id": "test-tenant-id"} result = lang_chain_tracer_v2._get_default_query_params() assert result == expected -# Test load_session method @patch("langchain.callbacks.tracers.langchain.requests.get") def test_load_session( mock_requests_get: Mock, @@ -577,23 +580,65 @@ def test_convert_run( converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run) converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run) - assert isinstance(converted_llm_run, Run) - assert isinstance(converted_chain_run, Run) - assert isinstance(converted_tool_run, Run) + assert isinstance(converted_llm_run, RunCreate) + assert isinstance(converted_chain_run, RunCreate) + assert isinstance(converted_tool_run, RunCreate) -@patch("langchain.callbacks.tracers.langchain.requests.post") def test_persist_run( - mock_requests_post: Mock, lang_chain_tracer_v2: LangChainTracerV2, sample_tracer_session_v2: TracerSessionV2, sample_runs: Tuple[LLMRun, ChainRun, ToolRun], ) -> None: - mock_requests_post.return_value.raise_for_status.return_value = None - lang_chain_tracer_v2.session = sample_tracer_session_v2 - llm_run, chain_run, tool_run = sample_runs - lang_chain_tracer_v2._persist_run(llm_run) - lang_chain_tracer_v2._persist_run(chain_run) - lang_chain_tracer_v2._persist_run(tool_run) + """Test that persist_run method calls requests.post once per method call.""" + with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( + "langchain.callbacks.tracers.langchain.requests.get" + ) as get: + post.return_value.raise_for_status.return_value = None + lang_chain_tracer_v2.session = sample_tracer_session_v2 + llm_run, chain_run, tool_run = sample_runs + lang_chain_tracer_v2._persist_run(llm_run) + lang_chain_tracer_v2._persist_run(chain_run) + lang_chain_tracer_v2._persist_run(tool_run) + + assert post.call_count == 3 + assert get.call_count == 0 + + +def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None: + """Test creating the 'SessionCreate' object.""" + lang_chain_tracer_v2.tenant_id = str(_TENANT_ID) + session_create = lang_chain_tracer_v2._get_session_create(name="test") + assert isinstance(session_create, TracerSessionV2Create) + assert session_create.name == "test" + assert session_create.tenant_id == _TENANT_ID + - assert mock_requests_post.call_count == 3 +@patch("langchain.callbacks.tracers.langchain.requests.post") +def test_persist_session( + mock_requests_post: Mock, + lang_chain_tracer_v2: LangChainTracerV2, + sample_tracer_session_v2: TracerSessionV2, +) -> None: + """Test persist_session returns a TracerSessionV2 with the updated ID.""" + session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict()) + new_id = str(uuid4()) + mock_requests_post.return_value.json.return_value = {"id": new_id} + result = lang_chain_tracer_v2._persist_session(session_create) + assert isinstance(result, TracerSessionV2) + res = sample_tracer_session_v2.dict() + res["id"] = UUID(new_id) + assert result.dict() == res + + +@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session") +def test_load_default_session( + mock_load_session: Mock, + lang_chain_tracer_v2: LangChainTracerV2, + sample_tracer_session_v2: TracerSessionV2, +) -> None: + """Test load_default_session attempts to load with the default name.""" + mock_load_session.return_value = sample_tracer_session_v2 + result = lang_chain_tracer_v2.load_default_session() + assert result == sample_tracer_session_v2 + mock_load_session.assert_called_with("default")