Add Tenant ID to V2 Tracer (#4135)

Update the V2 tracer to
- use UUIDs instead of int's
- load a tenant ID and use that when saving sessions
parallel_dir_loader
Zander Chase 1 year ago committed by GitHub
parent fea639c1fc
commit 6032a051e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,7 +6,7 @@ import os
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
from uuid import UUID, uuid4
from langchain.callbacks.base import (
@ -21,6 +21,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@ -28,7 +29,7 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
@ -48,7 +49,7 @@ def tracing_enabled(
) -> Generator[TracerSession, None, None]:
"""Get Tracer in a context manager."""
cb = LangChainTracer()
session = cb.load_session(session_name)
session = cast(TracerSession, cb.load_session(session_name))
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
@ -57,7 +58,7 @@ def tracing_enabled(
@contextmanager
def tracing_v2_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
) -> Generator[TracerSessionV2, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
warnings.warn(

@ -12,7 +12,9 @@ from langchain.callbacks.tracers.schemas import (
LLMRun,
ToolRun,
TracerSession,
TracerSessionBase,
TracerSessionCreate,
TracerSessionV2,
)
from langchain.schema import LLMResult
@ -27,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[TracerSession] = None
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
@staticmethod
def _add_child_run(
@ -49,22 +51,31 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Persist a run."""
@abstractmethod
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
def _persist_session(
self, session: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a tracing session."""
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionCreate(name=name, extra=kwargs)
def new_session(
self, name: Optional[str] = None, **kwargs: Any
) -> Union[TracerSession, TracerSessionV2]:
"""NOT thread safe, do not call this method from multiple threads."""
session_create = TracerSessionCreate(name=name, extra=kwargs)
session_create = self._get_session_create(name=name, **kwargs)
session = self._persist_session(session_create)
self.session = session
return session
@abstractmethod
def load_session(self, session_name: str) -> TracerSession:
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a tracing session and set it as the Tracer's session."""
@abstractmethod
def load_default_session(self) -> TracerSession:
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session."""
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:

@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import (
Run,
ToolRun,
TracerSession,
TracerSessionCreate,
TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
)
def _get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
return headers
def _get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
self.session = self.load_session(session_name)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer):
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
def _persist_session(
self, session_create: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a session."""
try:
r = requests.post(
@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer):
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> TracerSession:
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> TracerSession:
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")
def _get_tenant_id() -> Optional[str]:
"""Get the tenant ID for the LangChain API."""
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
if tenant_id:
return tenant_id
endpoint = _get_endpoint()
headers = _get_headers()
response = requests.get(endpoint + "/tenants", headers=headers)
response.raise_for_status()
tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint}")
return tenants[0]["id"]
class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@staticmethod
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
self.tenant_id = _get_tenant_id()
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
except Exception as e:
logging.warning(f"Failed to create session, using default session: {e}")
session = self.load_session("default")
return session
def _get_default_query_params(self) -> Dict[str, Any]:
"""Get the query params for the LangChain API."""
return {"tenant_id": self.tenant_id}
def load_session(self, session_name: str) -> TracerSessionV2:
"""Load a session with the given name from the tracer."""
try:
url = f"{self._endpoint}/sessions"
params = {"tenant_id": self.tenant_id}
if session_name:
params["name"] = session_name
r = requests.get(url, headers=self._headers, params=params)
tracer_session = TracerSessionV2(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
self.session = tracer_session
return tracer_session
def load_default_session(self) -> TracerSessionV2:
"""Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default")
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
inputs: Dict[str, Any] = {}
outputs: Optional[Dict[str, Any]] = None
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer):
return Run(
id=run.uuid,
name=run.serialized.get("name"),
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
extra=run.extra or {},
error=run.error,
execution_order=run.execution_order,
serialized=run.serialized,
inputs=inputs,
outputs=outputs,
session_id=run.session_id,
session_id=session.id,
run_type=run_type,
parent_run_id=run.parent_uuid,
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
child_runs=[self._convert_run(child) for child in child_runs],
)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
run_create = self._convert_run(run)
try:
requests.post(
result = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
result.raise_for_status()
except Exception as e:
logging.warning(f"Failed to persist run: {e}")

@ -31,6 +31,24 @@ class TracerSession(TracerSessionBase):
id: int
class TracerSessionV2Base(TracerSessionBase):
"""A creation class for TracerSessionV2."""
tenant_id: UUID
class TracerSessionV2Create(TracerSessionBase):
"""A creation class for TracerSessionV2."""
pass
class TracerSessionV2(TracerSessionV2Base):
"""TracerSession schema for the V2 API."""
id: UUID
class BaseRun(BaseModel):
"""Base class for Run."""
@ -93,9 +111,9 @@ class Run(BaseModel):
serialized: dict
inputs: dict
outputs: Optional[dict]
session_id: int
session_id: UUID
parent_run_id: Optional[UUID]
example_id: Optional[UUID]
reference_example_id: Optional[UUID]
run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list)

@ -2,8 +2,9 @@
from __future__ import annotations
from datetime import datetime
from typing import List, Union
from uuid import uuid4
from typing import List, Tuple, Union
from unittest.mock import Mock, patch
from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
@ -16,7 +17,8 @@ from langchain.callbacks.tracers.base import (
TracerException,
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
@ -27,7 +29,7 @@ def load_session(session_name: str) -> TracerSession:
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
def _persist_session(session: TracerSessionCreate) -> TracerSession:
def _persist_session(session: TracerSessionBase) -> TracerSession:
"""Persist a tracing session."""
return TracerSession(id=TEST_SESSION_ID, **session.dict())
@ -49,7 +51,7 @@ class FakeTracer(BaseTracer):
"""Persist a run."""
self.runs.append(run)
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
def _persist_session(self, session: TracerSessionBase) -> TracerSession:
"""Persist a tracing session."""
return _persist_session(session)
@ -473,3 +475,125 @@ def test_tracer_nested_runs_on_error() -> None:
)
assert tracer.runs == [compare_run] * 3
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV2()
return tracer
# Mock a sample TracerSessionV2 object
@pytest.fixture
def sample_tracer_session_v2() -> TracerSessionV2:
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
# Mock a sample LLMRun, ChainRun, and ToolRun objects
@pytest.fixture
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
llm_run = LLMRun(
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
session_id=1,
prompts=[],
response=LLMResult(generations=[[]]),
serialized={},
extra={},
)
chain_run = ChainRun(
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
child_execution_order=1,
session_id=1,
serialized={},
inputs={},
outputs=None,
child_llm_runs=[llm_run],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
tool_run = ToolRun(
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
session_id=1,
tool_input="test",
action="{}",
serialized={},
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
return llm_run, chain_run, tool_run
# Test _get_default_query_params method
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
expected = {"tenant_id": "test-tenant-id"}
result = lang_chain_tracer_v2._get_default_query_params()
assert result == expected
# Test load_session method
@patch("langchain.callbacks.tracers.langchain.requests.get")
def test_load_session(
mock_requests_get: Mock,
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
) -> None:
"""Test that load_session method returns a TracerSessionV2 object."""
mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()]
result = lang_chain_tracer_v2.load_session("test-session-name")
mock_requests_get.assert_called_with(
"http://test-endpoint.com/sessions",
headers={"Content-Type": "application/json", "x-api-key": "foo"},
params={"tenant_id": "test-tenant-id", "name": "test-session-name"},
)
assert result == sample_tracer_session_v2
def test_convert_run(
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
llm_run, chain_run, tool_run = sample_runs
lang_chain_tracer_v2.session = sample_tracer_session_v2
converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run)
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
assert isinstance(converted_llm_run, Run)
assert isinstance(converted_chain_run, Run)
assert isinstance(converted_tool_run, Run)
@patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_run(
mock_requests_post: Mock,
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
mock_requests_post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
llm_run, chain_run, tool_run = sample_runs
lang_chain_tracer_v2._persist_run(llm_run)
lang_chain_tracer_v2._persist_run(chain_run)
lang_chain_tracer_v2._persist_run(tool_run)
assert mock_requests_post.call_count == 3

@ -70,7 +70,7 @@ def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
assert file_contents is None
file_contents = Path(file_path).read_text()
mocked_responses.get(
mocked_responses.get( # type: ignore
urljoin(URL_BASE.format(ref=ref), path),
body=body,
status=200,
@ -86,7 +86,9 @@ def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
path = "chains/path/chain.json"
loader = Mock()
mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500)
mocked_responses.get( # type: ignore
urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
)
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})

Loading…
Cancel
Save