py tracer fixes (#5377)

searx_updates
Ankush Gola 12 months ago committed by GitHub
parent ce8b7a2a69
commit 1671c2afb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -347,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"id": "87027b0d-3a61-47cf-8a65-3002968be7f9",
"metadata": {
"tags": []
@ -356,13 +356,13 @@
"source": [
"import os\n",
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchainpro-api-gateway-12bfv6cf.uc.gateway.dev\" # Uncomment this line if you want to use the hosted version\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.langchain.plus\" # Uncomment this line if you want to use the hosted version\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = \"<YOUR-LANGCHAINPLUS-API-KEY>\" # Uncomment this line if you want to use the hosted version."
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 8,
"id": "5b4f49a2-7d09-4601-a8ba-976f0517c64c",
"metadata": {
"tags": []
@ -379,7 +379,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 9,
"id": "029b4a57-dc49-49de-8f03-53c292144e09",
"metadata": {
"tags": []
@ -397,7 +397,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 10,
"id": "91a85fb2-6027-4bd0-b1fe-2a3b3b79e2dd",
"metadata": {
"tags": []
@ -426,7 +426,7 @@
"'1.0891804557407723'"
]
},
"execution_count": 15,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}

@ -3,24 +3,35 @@ from __future__ import annotations
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
import requests
from tenacity import retry, stop_after_attempt, wait_fixed
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
Run,
RunCreate,
RunTypeEnum,
RunUpdate,
TracerSession,
TracerSessionCreate,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text
logger = logging.getLogger(__name__)
def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
@ -34,7 +45,27 @@ def get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
class LangChainTracerAPIError(Exception):
    """The LangChain API answered a tracer request with HTTP status 500.

    This is the exception type that ``retry_decorator`` is configured to
    retry on.
    """
class LangChainTracerUserError(Exception):
    """A tracer request failed with an HTTP error status other than 500.

    ``retry_decorator`` does not retry on this exception type.
    """
class LangChainTracerError(Exception):
    """A tracer request failed for a reason other than an HTTP error status.

    Raised from the catch-all ``except Exception`` branch of the tracer's
    request helpers, with the original exception chained as the cause.
    """
# Shared retry policy for the tracer's HTTP helpers.
retry_decorator = retry(
    # Give up after the third attempt.
    stop=stop_after_attempt(3),
    # Exponential backoff; per the tenacity docs the wait is kept between
    # `min` and `max` seconds.
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # Only LangChainTracerAPIError triggers another attempt; any other
    # exception propagates immediately.
    retry=retry_if_exception_type(LangChainTracerAPIError),
    # Log a WARNING through this module's logger before each sleep.
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
@retry_decorator
def _get_tenant_id(
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
) -> str:
@ -44,8 +75,24 @@ def _get_tenant_id(
return tenant_id_
endpoint_ = endpoint or get_endpoint()
headers_ = headers or get_headers()
response = requests.get(endpoint_ + "/tenants", headers=headers_)
raise_for_status_with_text(response)
response = None
try:
response = requests.get(endpoint_ + "/tenants", headers=headers_)
raise_for_status_with_text(response)
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to get tenant ID from LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to get tenant ID from LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to get tenant ID from LangChain API. {e}"
) from e
tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint_}")
@ -72,6 +119,8 @@ class LangChainTracer(BaseTracer):
self.example_id = example_id
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
self.session_extra = session_extra
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)
def on_chat_model_start(
self,
@ -108,7 +157,7 @@ class LangChainTracer(BaseTracer):
self.tenant_id = tenant_id
return tenant_id
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
@retry_decorator
def ensure_session(self) -> TracerSession:
"""Upsert a session."""
if self.session is not None:
@ -118,37 +167,124 @@ class LangChainTracer(BaseTracer):
session_create = TracerSessionCreate(
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
)
r = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
raise_for_status_with_text(r)
self.session = TracerSession(**r.json())
response = None
try:
response = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert session to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to upsert session to LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to upsert session to LangChain API. {e}"
) from e
self.session = TracerSession(**response.json())
return self.session
def _persist_run_nested(self, run: Run) -> None:
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
@retry_decorator
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
session = self.ensure_session()
child_runs = run.child_runs
if run.parent_run_id is None:
run.reference_example_id = self.example_id
run_dict = run.dict()
del run_dict["child_runs"]
run_create = RunCreate(**run_dict, session_id=session.id)
response = None
try:
response = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert persist run to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to persist run to LangChain API. {e}"
)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
for child_run in child_runs:
child_run.parent_run_id = run.id
self._persist_run_nested(child_run)
raise LangChainTracerError(
f"Failed to persist run to LangChain API. {e}"
) from e
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
run.reference_example_id = self.example_id
# TODO: Post first then patch
self._persist_run_nested(run)
@retry_decorator
def _update_run_single(self, run: Run) -> None:
    """Send an update for an existing run to the LangChain API.

    Builds a ``RunUpdate`` from ``run.dict()`` and issues
    ``PATCH {endpoint}/runs/{run.id}``.

    Args:
        run: The run whose current state should be pushed.

    Raises:
        LangChainTracerAPIError: The response status code was 500.
        LangChainTracerUserError: Any other HTTP error status.
        LangChainTracerError: Any non-HTTP failure while making the request.
    """
    run_update = RunUpdate(**run.dict())
    response = None
    try:
        response = requests.patch(
            f"{self._endpoint}/runs/{run.id}",
            data=run_update.json(),
            headers=self._headers,
        )
        response.raise_for_status()
    except HTTPError as e:
        # Only an exact 500 is mapped to the retried error type; every other
        # HTTP error status (including other 5xx codes) becomes a user error.
        if response is not None and response.status_code == 500:
            raise LangChainTracerAPIError(
                f"Failed to update run to LangChain API. {e}"
            ) from e
        raise LangChainTracerUserError(
            f"Failed to update run to LangChain API. {e}"
        ) from e
    except Exception as e:
        raise LangChainTracerError(
            f"Failed to update run to LangChain API. {e}"
        ) from e
def _on_llm_start(self, run: Run) -> None:
    """Queue creation of an LLM run when it starts.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_persist_run_single`` with it.
    """
    # presumably copied so later changes to the live run do not affect the
    # queued request — confirm with the author
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)
def _on_chat_model_start(self, run: Run) -> None:
    """Queue creation of a chat-model run when it starts.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_persist_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)
def _on_llm_end(self, run: Run) -> None:
    """Queue an update for an LLM run that has finished.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_update_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)
def _on_llm_error(self, run: Run) -> None:
    """Queue an update for an LLM run that ended with an error.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_update_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)
def _on_chain_start(self, run: Run) -> None:
    """Queue creation of a chain run when it starts.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_persist_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)
def _on_chain_end(self, run: Run) -> None:
    """Queue an update for a chain run that has finished.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_update_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)
def _on_chain_error(self, run: Run) -> None:
    """Queue an update for a chain run that ended with an error.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_update_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)
def _on_tool_start(self, run: Run) -> None:
    """Queue creation of a tool run when it starts.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_persist_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)
def _on_tool_end(self, run: Run) -> None:
    """Queue an update for a tool run that has finished.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_update_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)
def _on_tool_error(self, run: Run) -> None:
    """Queue an update for a tool run that ended with an error.

    A deep copy of ``run`` is submitted to the tracer's executor, which
    calls ``_update_run_single`` with it.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)

@ -91,6 +91,9 @@ class ToolRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list)
# Begin V2 API Schemas
class RunTypeEnum(str, Enum):
"""Enum for run types."""
@ -105,7 +108,7 @@ class RunBase(BaseModel):
id: Optional[UUID]
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict
extra: Optional[Dict[str, Any]] = None
error: Optional[str]
execution_order: int
child_execution_order: Optional[int]
@ -144,5 +147,13 @@ class RunCreate(RunBase):
return values
class RunUpdate(BaseModel):
    """Subset of run fields that can be sent when updating a run.

    Every field is declared ``Optional``.
    """

    end_time: Optional[datetime.datetime]
    error: Optional[str]
    outputs: Optional[dict]
    parent_run_id: Optional[UUID]
    reference_example_id: Optional[UUID]
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()

@ -8,6 +8,7 @@ from aiohttp import ClientSession
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
questions = [
@ -140,10 +141,10 @@ async def test_tracing_v2_environment_variable() -> None:
def test_tracing_v2_context_manager() -> None:
llm = OpenAI(temperature=0)
llm = ChatOpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING_V2" in os.environ:
del os.environ["LANGCHAIN_TRACING_V2"]

@ -1,134 +0,0 @@
"""Test Tracer classes."""
from __future__ import annotations
import json
from datetime import datetime
from typing import Tuple
from unittest.mock import patch
from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.schema import LLMResult
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
    """Return a ``LangChainTracer`` built after setting test LangChain env vars.

    The tenant ID, endpoint and API key environment variables are set through
    ``monkeypatch`` before the tracer is constructed.
    """
    test_env = {
        "LANGCHAIN_TENANT_ID": "test-tenant-id",
        "LANGCHAIN_ENDPOINT": "http://test-endpoint.com",
        "LANGCHAIN_API_KEY": "foo",
    }
    for var_name, var_value in test_env.items():
        monkeypatch.setenv(var_name, var_value)
    return LangChainTracer()
# Sample TracerSession used by the tracer tests.
@pytest.fixture
def sample_tracer_session_v2() -> TracerSession:
    """Return a ``TracerSession`` using the module's fixed session and tenant IDs."""
    session_fields = {
        "id": _SESSION_ID,
        "name": "Sample session",
        "tenant_id": _TENANT_ID,
    }
    return TracerSession(**session_fields)
# NOTE(review): freeze_time is applied outside pytest.fixture here; confirm the
# frozen clock is actually in effect when pytest invokes the fixture.
@freeze_time("2023-01-01")
@pytest.fixture
def sample_runs() -> Tuple[Run, Run, Run]:
    """Build three sample runs and return them as ``(llm_run, chain_run, tool_run)``.

    The LLM run names the chain run as its parent, the chain run lists the LLM
    run as its only child, and the tool run has no children.
    """
    llm_run_id = "57a08cc4-73d2-4236-8370-549099d07fad"
    chain_run_id = "57a08cc4-73d2-4236-8371-549099d07fad"
    tool_run_id = "57a08cc4-73d2-4236-8372-549099d07fad"

    llm_run = Run(
        id=llm_run_id,
        name="llm_run",
        execution_order=1,
        child_execution_order=1,
        parent_run_id=chain_run_id,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        session_id=1,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]).dict(),
        serialized={},
        extra={},
        run_type=RunTypeEnum.llm,
    )

    chain_run = Run(
        id=chain_run_id,
        name="chain_run",
        execution_order=1,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        child_execution_order=1,
        serialized={},
        inputs={},
        outputs={},
        child_runs=[llm_run],
        extra={},
        run_type=RunTypeEnum.chain,
    )

    tool_run = Run(
        id=tool_run_id,
        name="tool_run",
        execution_order=1,
        child_execution_order=1,
        inputs={"input": "test"},
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        outputs=None,
        serialized={},
        child_runs=[],
        extra={},
        run_type=RunTypeEnum.tool,
    )

    return llm_run, chain_run, tool_run
def test_persist_run(
    lang_chain_tracer_v2: LangChainTracer,
    sample_tracer_session_v2: TracerSession,
    sample_runs: Tuple[Run, Run, Run],
) -> None:
    """Ending each of the three sample runs issues exactly one POST and no GET."""
    tracer = lang_chain_tracer_v2
    with patch("langchain.callbacks.tracers.langchain.requests.post") as post:
        with patch("langchain.callbacks.tracers.langchain.requests.get") as get:
            post.return_value.raise_for_status.return_value = None
            # Pre-set the session so the tracer has no reason to create one.
            tracer.session = sample_tracer_session_v2
            # Register every run before ending any of them.
            tracer.run_map.update({str(run.id): run for run in sample_runs})
            for run in sample_runs:
                tracer._end_trace(run)
            assert post.call_count == 3
            assert get.call_count == 0
def test_persist_run_with_example_id(
    lang_chain_tracer_v2: LangChainTracer,
    sample_tracer_session_v2: TracerSession,
    sample_runs: Tuple[Run, Run, Run],
) -> None:
    """Only the top-level run's payload carries the example ID; nested runs do not."""
    tracer = lang_chain_tracer_v2
    example_id = uuid4()
    llm_run, chain_run, tool_run = sample_runs
    # Re-nest the sample runs as chain -> tool -> llm.
    chain_run.child_runs = [tool_run]
    tool_run.child_runs = [llm_run]
    with patch("langchain.callbacks.tracers.langchain.requests.post") as post:
        with patch("langchain.callbacks.tracers.langchain.requests.get") as get:
            post.return_value.raise_for_status.return_value = None
            tracer.session = sample_tracer_session_v2
            tracer.example_id = example_id
            tracer._persist_run(chain_run)

            assert post.call_count == 3
            assert get.call_count == 0
            payloads = [
                json.loads(recorded_call[1]["data"])
                for recorded_call in post.call_args_list
            ]
            chain_payload, tool_payload, llm_payload = payloads
            assert chain_payload["id"] == str(chain_run.id)
            assert chain_payload["reference_example_id"] == str(example_id)
            assert tool_payload["id"] == str(tool_run.id)
            assert not tool_payload.get("reference_example_id")
            assert llm_payload["id"] == str(llm_run.id)
            assert not llm_payload.get("reference_example_id")
Loading…
Cancel
Save