From 1671c2afb2fbc03207ca9d544d0a28e8ad338859 Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Tue, 30 May 2023 18:47:06 -0700 Subject: [PATCH] py tracer fixes (#5377) --- docs/tracing/agent_with_tracing.ipynb | 12 +- langchain/callbacks/tracers/langchain.py | 184 +++++++++++++++--- langchain/callbacks/tracers/schemas.py | 13 +- .../callbacks/test_langchain_tracer.py | 5 +- .../callbacks/tracers/test_tracer.py | 134 ------------- 5 files changed, 181 insertions(+), 167 deletions(-) delete mode 100644 tests/unit_tests/callbacks/tracers/test_tracer.py diff --git a/docs/tracing/agent_with_tracing.ipynb b/docs/tracing/agent_with_tracing.ipynb index ce269d82..f8ee0e47 100644 --- a/docs/tracing/agent_with_tracing.ipynb +++ b/docs/tracing/agent_with_tracing.ipynb @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "id": "87027b0d-3a61-47cf-8a65-3002968be7f9", "metadata": { "tags": [] @@ -356,13 +356,13 @@ "source": [ "import os\n", "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", - "# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchainpro-api-gateway-12bfv6cf.uc.gateway.dev\" # Uncomment this line if you want to use the hosted version\n", + "# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.langchain.plus\" # Uncomment this line if you want to use the hosted version\n", "# os.environ[\"LANGCHAIN_API_KEY\"] = \"\" # Uncomment this line if you want to use the hosted version." ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 8, "id": "5b4f49a2-7d09-4601-a8ba-976f0517c64c", "metadata": { "tags": [] @@ -379,7 +379,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 9, "id": "029b4a57-dc49-49de-8f03-53c292144e09", "metadata": { "tags": [] @@ -397,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "id": "91a85fb2-6027-4bd0-b1fe-2a3b3b79e2dd", "metadata": { "tags": [] @@ -426,7 +426,7 @@ "'1.0891804557407723'" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index c580f059..cda5e8d8 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -3,24 +3,35 @@ from __future__ import annotations import logging import os +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any, Dict, List, Optional from uuid import UUID import requests -from tenacity import retry, stop_after_attempt, wait_fixed +from requests.exceptions import HTTPError +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import ( Run, RunCreate, RunTypeEnum, + RunUpdate, TracerSession, TracerSessionCreate, ) from langchain.schema import BaseMessage, messages_to_dict from langchain.utils import raise_for_status_with_text +logger = logging.getLogger(__name__) + def get_headers() -> Dict[str, Any]: """Get the headers for the LangChain API.""" @@ -34,7 +45,27 @@ def get_endpoint() -> str: return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984") -@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) +class LangChainTracerAPIError(Exception): + """An error occurred while communicating with the LangChain API.""" + + +class LangChainTracerUserError(Exception): + """An error occurred while communicating with the LangChain API.""" + + +class LangChainTracerError(Exception): + """An error occurred while communicating with the LangChain API.""" + + +retry_decorator = retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(LangChainTracerAPIError), + before_sleep=before_sleep_log(logger, logging.WARNING), +) + + +@retry_decorator def _get_tenant_id( tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict] ) -> str: @@ -44,8 +75,24 @@ def _get_tenant_id( return tenant_id_ endpoint_ = endpoint or get_endpoint() headers_ = headers or get_headers() - response = requests.get(endpoint_ + "/tenants", headers=headers_) - raise_for_status_with_text(response) + response = None + try: + response = requests.get(endpoint_ + "/tenants", headers=headers_) + raise_for_status_with_text(response) + except HTTPError as e: + if response is not None and response.status_code == 500: + raise LangChainTracerAPIError( + f"Failed to get tenant ID from LangChain API. {e}" + ) + else: + raise LangChainTracerUserError( + f"Failed to get tenant ID from LangChain API. {e}" + ) + except Exception as e: + raise LangChainTracerError( + f"Failed to get tenant ID from LangChain API. {e}" + ) from e + tenants: List[Dict[str, Any]] = response.json() if not tenants: raise ValueError(f"No tenants found for URL {endpoint_}") @@ -72,6 +119,8 @@ class LangChainTracer(BaseTracer): self.example_id = example_id self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default") self.session_extra = session_extra + # set max_workers to 1 to process tasks in order + self.executor = ThreadPoolExecutor(max_workers=1) def on_chat_model_start( self, @@ -108,7 +157,7 @@ class LangChainTracer(BaseTracer): self.tenant_id = tenant_id return tenant_id - @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) + @retry_decorator def ensure_session(self) -> TracerSession: """Upsert a session.""" if self.session is not None: @@ -118,37 +167,124 @@ class LangChainTracer(BaseTracer): session_create = TracerSessionCreate( name=self.session_name, extra=self.session_extra, tenant_id=tenant_id ) - r = requests.post( - url, - data=session_create.json(), - headers=self._headers, - ) - raise_for_status_with_text(r) - self.session = TracerSession(**r.json()) + response = None + try: + response = requests.post( + url, + data=session_create.json(), + headers=self._headers, + ) + response.raise_for_status() + except HTTPError as e: + if response is not None and response.status_code == 500: + raise LangChainTracerAPIError( + f"Failed to upsert session to LangChain API. {e}" + ) + else: + raise LangChainTracerUserError( + f"Failed to upsert session to LangChain API. {e}" + ) + except Exception as e: + raise LangChainTracerError( + f"Failed to upsert session to LangChain API. {e}" + ) from e + self.session = TracerSession(**response.json()) return self.session - def _persist_run_nested(self, run: Run) -> None: + def _persist_run(self, run: Run) -> None: + """Persist a run.""" + + @retry_decorator + def _persist_run_single(self, run: Run) -> None: """Persist a run.""" session = self.ensure_session() - child_runs = run.child_runs + if run.parent_run_id is None: + run.reference_example_id = self.example_id run_dict = run.dict() del run_dict["child_runs"] run_create = RunCreate(**run_dict, session_id=session.id) + response = None try: response = requests.post( f"{self._endpoint}/runs", data=run_create.json(), headers=self._headers, ) - raise_for_status_with_text(response) + response.raise_for_status() + except HTTPError as e: + if response is not None and response.status_code == 500: + raise LangChainTracerAPIError( + f"Failed to upsert persist run to LangChain API. {e}" + ) + else: + raise LangChainTracerUserError( + f"Failed to persist run to LangChain API. {e}" + ) except Exception as e: - logging.warning(f"Failed to persist run: {e}") - for child_run in child_runs: - child_run.parent_run_id = run.id - self._persist_run_nested(child_run) + raise LangChainTracerError( + f"Failed to persist run to LangChain API. {e}" + ) from e - def _persist_run(self, run: Run) -> None: - """Persist a run.""" - run.reference_example_id = self.example_id - # TODO: Post first then patch - self._persist_run_nested(run) + @retry_decorator + def _update_run_single(self, run: Run) -> None: + """Update a run.""" + run_update = RunUpdate(**run.dict()) + response = None + try: + response = requests.patch( + f"{self._endpoint}/runs/{run.id}", + data=run_update.json(), + headers=self._headers, + ) + response.raise_for_status() + except HTTPError as e: + if response is not None and response.status_code == 500: + raise LangChainTracerAPIError( + f"Failed to update run to LangChain API. {e}" + ) + else: + raise LangChainTracerUserError(f"Failed to run to LangChain API. {e}") + except Exception as e: + raise LangChainTracerError( + f"Failed to update run to LangChain API. {e}" + ) from e + + def _on_llm_start(self, run: Run) -> None: + """Persist an LLM run.""" + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + + def _on_chat_model_start(self, run: Run) -> None: + """Persist an LLM run.""" + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + + def _on_llm_end(self, run: Run) -> None: + """Process the LLM Run.""" + self.executor.submit(self._update_run_single, run.copy(deep=True)) + + def _on_llm_error(self, run: Run) -> None: + """Process the LLM Run upon error.""" + self.executor.submit(self._update_run_single, run.copy(deep=True)) + + def _on_chain_start(self, run: Run) -> None: + """Process the Chain Run upon start.""" + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + + def _on_chain_end(self, run: Run) -> None: + """Process the Chain Run.""" + self.executor.submit(self._update_run_single, run.copy(deep=True)) + + def _on_chain_error(self, run: Run) -> None: + """Process the Chain Run upon error.""" + self.executor.submit(self._update_run_single, run.copy(deep=True)) + + def _on_tool_start(self, run: Run) -> None: + """Process the Tool Run upon start.""" + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + + def _on_tool_end(self, run: Run) -> None: + """Process the Tool Run.""" + self.executor.submit(self._update_run_single, run.copy(deep=True)) + + def _on_tool_error(self, run: Run) -> None: + """Process the Tool Run upon error.""" + self.executor.submit(self._update_run_single, run.copy(deep=True)) diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index a1252ac5..8c5a7d92 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -91,6 +91,9 @@ class ToolRun(BaseRun): child_tool_runs: List[ToolRun] = Field(default_factory=list) +# Begin V2 API Schemas + + class RunTypeEnum(str, Enum): """Enum for run types.""" @@ -105,7 +108,7 @@ class RunBase(BaseModel): id: Optional[UUID] start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - extra: dict + extra: Optional[Dict[str, Any]] = None error: Optional[str] execution_order: int child_execution_order: Optional[int] @@ -144,5 +147,13 @@ class RunCreate(RunBase): return values +class RunUpdate(BaseModel): + end_time: Optional[datetime.datetime] + error: Optional[str] + outputs: Optional[dict] + parent_run_id: Optional[UUID] + reference_example_id: Optional[UUID] + + ChainRun.update_forward_refs() ToolRun.update_forward_refs() diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index 781da6f7..1caf5ebc 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -8,6 +8,7 @@ from aiohttp import ClientSession from langchain.agents import AgentType, initialize_agent, load_tools from langchain.callbacks import tracing_enabled from langchain.callbacks.manager import tracing_v2_enabled +from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI questions = [ @@ -140,10 +141,10 @@ async def test_tracing_v2_environment_variable() -> None: def test_tracing_v2_context_manager() -> None: - llm = OpenAI(temperature=0) + llm = ChatOpenAI(temperature=0) tools = load_tools(["llm-math", "serpapi"], llm=llm) agent = initialize_agent( - tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True ) if "LANGCHAIN_TRACING_V2" in os.environ: del os.environ["LANGCHAIN_TRACING_V2"] diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py deleted file mode 100644 index 055ac640..00000000 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Test Tracer classes.""" -from __future__ import annotations - -import json -from datetime import datetime -from typing import Tuple -from unittest.mock import patch -from uuid import UUID, uuid4 - -import pytest -from freezegun import freeze_time - -from langchain.callbacks.tracers.langchain import LangChainTracer -from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession -from langchain.schema import LLMResult - -_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a") -_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad") - - -@pytest.fixture -def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer: - monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") - monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") - monkeypatch.setenv("LANGCHAIN_API_KEY", "foo") - tracer = LangChainTracer() - return tracer - - -# Mock a sample TracerSession object -@pytest.fixture -def sample_tracer_session_v2() -> TracerSession: - return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID) - - -@freeze_time("2023-01-01") -@pytest.fixture -def sample_runs() -> Tuple[Run, Run, Run]: - llm_run = Run( - id="57a08cc4-73d2-4236-8370-549099d07fad", - name="llm_run", - execution_order=1, - child_execution_order=1, - parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad", - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - session_id=1, - inputs={"prompts": []}, - outputs=LLMResult(generations=[[]]).dict(), - serialized={}, - extra={}, - run_type=RunTypeEnum.llm, - ) - chain_run = Run( - id="57a08cc4-73d2-4236-8371-549099d07fad", - name="chain_run", - execution_order=1, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - child_execution_order=1, - serialized={}, - inputs={}, - outputs={}, - child_runs=[llm_run], - extra={}, - run_type=RunTypeEnum.chain, - ) - - tool_run = Run( - id="57a08cc4-73d2-4236-8372-549099d07fad", - name="tool_run", - execution_order=1, - child_execution_order=1, - inputs={"input": "test"}, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - outputs=None, - serialized={}, - child_runs=[], - extra={}, - run_type=RunTypeEnum.tool, - ) - return llm_run, chain_run, tool_run - - -def test_persist_run( - lang_chain_tracer_v2: LangChainTracer, - sample_tracer_session_v2: TracerSession, - sample_runs: Tuple[Run, Run, Run], -) -> None: - """Test that persist_run method calls requests.post once per method call.""" - with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( - "langchain.callbacks.tracers.langchain.requests.get" - ) as get: - post.return_value.raise_for_status.return_value = None - lang_chain_tracer_v2.session = sample_tracer_session_v2 - for run in sample_runs: - lang_chain_tracer_v2.run_map[str(run.id)] = run - for run in sample_runs: - lang_chain_tracer_v2._end_trace(run) - - assert post.call_count == 3 - assert get.call_count == 0 - - -def test_persist_run_with_example_id( - lang_chain_tracer_v2: LangChainTracer, - sample_tracer_session_v2: TracerSession, - sample_runs: Tuple[Run, Run, Run], -) -> None: - """Test the example ID is assigned only to the parent run and not the children.""" - example_id = uuid4() - llm_run, chain_run, tool_run = sample_runs - chain_run.child_runs = [tool_run] - tool_run.child_runs = [llm_run] - with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( - "langchain.callbacks.tracers.langchain.requests.get" - ) as get: - post.return_value.raise_for_status.return_value = None - lang_chain_tracer_v2.session = sample_tracer_session_v2 - lang_chain_tracer_v2.example_id = example_id - lang_chain_tracer_v2._persist_run(chain_run) - - assert post.call_count == 3 - assert get.call_count == 0 - posted_data = [ - json.loads(call_args[1]["data"]) for call_args in post.call_args_list - ] - assert posted_data[0]["id"] == str(chain_run.id) - assert posted_data[0]["reference_example_id"] == str(example_id) - assert posted_data[1]["id"] == str(tool_run.id) - assert not posted_data[1].get("reference_example_id") - assert posted_data[2]["id"] == str(llm_run.id) - assert not posted_data[2].get("reference_example_id")