Add Tenant ID to V2 Tracer (#4135)

Update the V2 tracer to
- use UUIDs instead of ints
- load a tenant ID and use that when saving sessions
pull/4159/head
Zander Chase 1 year ago committed by GitHub
parent fea639c1fc
commit 6032a051e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,7 +6,7 @@ import os
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from langchain.callbacks.base import ( from langchain.callbacks.base import (
@ -21,6 +21,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2 from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@ -28,7 +29,7 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None "openai_callback", default=None
) )
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback", default=None "tracing_callback", default=None
) )
@ -48,7 +49,7 @@ def tracing_enabled(
) -> Generator[TracerSession, None, None]: ) -> Generator[TracerSession, None, None]:
"""Get Tracer in a context manager.""" """Get Tracer in a context manager."""
cb = LangChainTracer() cb = LangChainTracer()
session = cb.load_session(session_name) session = cast(TracerSession, cb.load_session(session_name))
tracing_callback_var.set(cb) tracing_callback_var.set(cb)
yield session yield session
tracing_callback_var.set(None) tracing_callback_var.set(None)
@ -57,7 +58,7 @@ def tracing_enabled(
@contextmanager @contextmanager
def tracing_v2_enabled( def tracing_v2_enabled(
session_name: str = "default", session_name: str = "default",
) -> Generator[TracerSession, None, None]: ) -> Generator[TracerSessionV2, None, None]:
"""Get the experimental tracer handler in a context manager.""" """Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental # Issue a warning that this is experimental
warnings.warn( warnings.warn(

@ -12,7 +12,9 @@ from langchain.callbacks.tracers.schemas import (
LLMRun, LLMRun,
ToolRun, ToolRun,
TracerSession, TracerSession,
TracerSessionBase,
TracerSessionCreate, TracerSessionCreate,
TracerSessionV2,
) )
from langchain.schema import LLMResult from langchain.schema import LLMResult
@ -27,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[TracerSession] = None self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
@staticmethod @staticmethod
def _add_child_run( def _add_child_run(
@ -49,22 +51,31 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Persist a run.""" """Persist a run."""
@abstractmethod @abstractmethod
def _persist_session(self, session: TracerSessionCreate) -> TracerSession: def _persist_session(
self, session: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a tracing session.""" """Persist a tracing session."""
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession: def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionCreate(name=name, extra=kwargs)
def new_session(
self, name: Optional[str] = None, **kwargs: Any
) -> Union[TracerSession, TracerSessionV2]:
"""NOT thread safe, do not call this method from multiple threads.""" """NOT thread safe, do not call this method from multiple threads."""
session_create = TracerSessionCreate(name=name, extra=kwargs) session_create = self._get_session_create(name=name, **kwargs)
session = self._persist_session(session_create) session = self._persist_session(session_create)
self.session = session self.session = session
return session return session
@abstractmethod @abstractmethod
def load_session(self, session_name: str) -> TracerSession: def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a tracing session and set it as the Tracer's session.""" """Load a tracing session and set it as the Tracer's session."""
@abstractmethod @abstractmethod
def load_default_session(self) -> TracerSession: def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session.""" """Load the default tracing session and set it as the Tracer's session."""
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:

@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import (
Run, Run,
ToolRun, ToolRun,
TracerSession, TracerSession,
TracerSessionCreate, TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
) )
def _get_headers() -> Dict[str, Any]:
    """Build the HTTP headers sent with LangChain API requests.

    The JSON content type is always present. ``x-api-key`` is included only
    when the ``LANGCHAIN_API_KEY`` environment variable is set to a non-empty
    value.
    """
    api_key = os.getenv("LANGCHAIN_API_KEY")
    if not api_key:
        return {"Content-Type": "application/json"}
    return {"Content-Type": "application/json", "x-api-key": api_key}
def _get_endpoint() -> str:
    """Return the LangChain API base URL.

    Reads ``LANGCHAIN_ENDPOINT`` and falls back to ``http://localhost:8000``
    when the variable is not set.
    """
    return os.environ.get("LANGCHAIN_ENDPOINT", "http://localhost:8000")
class LangChainTracer(BaseTracer): class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint.""" """An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, session_name: str = "default", **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer.""" """Initialize the LangChain tracer."""
super().__init__(**kwargs) super().__init__(**kwargs)
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") self._endpoint = _get_endpoint()
self._headers: Dict[str, Any] = {"Content-Type": "application/json"} self._headers = _get_headers()
if os.getenv("LANGCHAIN_API_KEY"):
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
self.session = self.load_session(session_name)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run.""" """Persist a run."""
@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer):
except Exception as e: except Exception as e:
logging.warning(f"Failed to persist run: {e}") logging.warning(f"Failed to persist run: {e}")
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession: def _persist_session(
self, session_create: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a session.""" """Persist a session."""
try: try:
r = requests.post( r = requests.post(
@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer):
self.session = tracer_session self.session = tracer_session
return tracer_session return tracer_session
def load_session(self, session_name: str) -> TracerSession: def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a session with the given name from the tracer.""" """Load a session with the given name from the tracer."""
return self._load_session(session_name) return self._load_session(session_name)
def load_default_session(self) -> TracerSession: def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session.""" """Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default") return self._load_session("default")
def _get_tenant_id() -> Optional[str]:
    """Resolve the tenant ID used for LangChain API requests.

    A non-empty ``LANGCHAIN_TENANT_ID`` environment variable takes
    precedence. Otherwise the ``/tenants`` route of the configured endpoint
    is queried and the ``id`` of the first tenant in the response is used.

    Returns:
        The tenant ID.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
        ValueError: If the endpoint reports no tenants.
    """
    tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
    if tenant_id:
        return tenant_id
    endpoint = _get_endpoint()
    headers = _get_headers()
    # requests waits forever by default; this runs while a tracer is being
    # constructed, so bound the wait instead of hanging the caller.
    response = requests.get(endpoint + "/tenants", headers=headers, timeout=10)
    response.raise_for_status()
    tenants: List[Dict[str, Any]] = response.json()
    if not tenants:
        raise ValueError(f"No tenants found for URL {endpoint}")
    return tenants[0]["id"]
class LangChainTracerV2(LangChainTracer): class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint.""" """An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@staticmethod def __init__(self, **kwargs: Any) -> None:
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run: """Initialize the LangChain tracer."""
"""Convert a run to a Run.""" super().__init__(**kwargs)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
self.tenant_id = _get_tenant_id()
def _get_session_create(
    self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
    """Build the V2 session-creation payload.

    The session name, any extra keyword arguments (passed as ``extra``) and
    this tracer's tenant ID are handed to ``TracerSessionV2Create``.
    """
    payload = TracerSessionV2Create(
        name=name,
        extra=kwargs,
        tenant_id=self.tenant_id,
    )
    return payload
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
    """Persist a session.

    POSTs the serialized ``session_create`` to ``{endpoint}/sessions`` and
    builds a ``TracerSessionV2`` from the returned ``id`` plus the fields of
    ``session_create``.

    Any failure (transport error, error status, malformed response,
    validation error) is logged as a warning and the result of
    ``self.load_session("default")`` is returned instead.
    """
    try:
        r = requests.post(
            f"{self._endpoint}/sessions",
            data=session_create.json(),
            headers=self._headers,
        )
        # Surface HTTP errors explicitly, as _persist_run does; otherwise an
        # error response only shows up as an opaque KeyError on "id".
        r.raise_for_status()
        session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
    except Exception as e:
        logging.warning(f"Failed to create session, using default session: {e}")
        session = self.load_session("default")
    return session
def _get_default_query_params(self) -> Dict[str, Any]:
    """Return the query parameters for LangChain API requests.

    Currently this is just the tracer's tenant ID under ``tenant_id``.
    """
    return dict(tenant_id=self.tenant_id)
def load_session(self, session_name: str) -> TracerSessionV2:
    """Load a session with the given name from the tracer.

    Queries ``{endpoint}/sessions`` with the tracer's tenant ID and, when
    ``session_name`` is non-empty, the session name, then uses the first
    entry of the JSON response. The result is stored on ``self.session``.

    If anything goes wrong, a warning is logged and a placeholder session
    with a freshly generated UUID is used instead.
    """
    # Local import: the fallback below needs a UUID and must not rely on what
    # the surrounding module happens to import.
    from uuid import uuid4

    try:
        url = f"{self._endpoint}/sessions"
        params = {"tenant_id": self.tenant_id}
        if session_name:
            params["name"] = session_name
        r = requests.get(url, headers=self._headers, params=params)
        tracer_session = TracerSessionV2(**r.json()[0])
    except Exception as e:
        session_type = "default" if not session_name else session_name
        logging.warning(
            f"Failed to load {session_type} session, using empty session: {e}"
        )
        # TracerSessionV2.id is a UUID field; the previous integer
        # placeholder (1) failed validation and made the fallback itself raise.
        tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)

    self.session = tracer_session
    return tracer_session
def load_default_session(self) -> TracerSessionV2:
    """Load the session named ``"default"`` via ``load_session``."""
    default_session = self.load_session("default")
    return default_session
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
inputs: Dict[str, Any] = {} inputs: Dict[str, Any] = {}
outputs: Optional[Dict[str, Any]] = None outputs: Optional[Dict[str, Any]] = None
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer):
return Run( return Run(
id=run.uuid, id=run.uuid,
name=run.serialized.get("name"), name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
start_time=run.start_time, start_time=run.start_time,
end_time=run.end_time, end_time=run.end_time,
extra=run.extra, extra=run.extra or {},
error=run.error, error=run.error,
execution_order=run.execution_order, execution_order=run.execution_order,
serialized=run.serialized, serialized=run.serialized,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
session_id=run.session_id, session_id=session.id,
run_type=run_type, run_type=run_type,
parent_run_id=run.parent_uuid, parent_run_id=run.parent_uuid,
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs], child_runs=[self._convert_run(child) for child in child_runs],
) )
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run.""" """Persist a run."""
run_create = self._convert_run(run) run_create = self._convert_run(run)
try: try:
requests.post( result = requests.post(
f"{self._endpoint}/runs", f"{self._endpoint}/runs",
data=run_create.json(), data=run_create.json(),
headers=self._headers, headers=self._headers,
) )
result.raise_for_status()
except Exception as e: except Exception as e:
logging.warning(f"Failed to persist run: {e}") logging.warning(f"Failed to persist run: {e}")

@ -31,6 +31,24 @@ class TracerSession(TracerSessionBase):
id: int id: int
class TracerSessionV2Base(TracerSessionBase):
    """Base schema for V2 tracer sessions: TracerSessionBase plus a tenant ID."""

    # ID of the tenant the session is associated with.
    tenant_id: UUID
class TracerSessionV2Create(TracerSessionV2Base):
    """A creation class for TracerSessionV2.

    Derives from ``TracerSessionV2Base`` so that ``tenant_id`` is a declared
    field. The V2 tracer passes ``tenant_id`` when building this payload, and
    it must be present in the serialized request body and in ``.dict()`` when
    the persisted ``TracerSessionV2`` is built from it.
    """
class TracerSessionV2(TracerSessionV2Base):
    """TracerSession schema for the V2 API."""

    # Session identifier; a UUID here, whereas the V1 TracerSession uses an int.
    id: UUID
class BaseRun(BaseModel): class BaseRun(BaseModel):
"""Base class for Run.""" """Base class for Run."""
@ -93,9 +111,9 @@ class Run(BaseModel):
serialized: dict serialized: dict
inputs: dict inputs: dict
outputs: Optional[dict] outputs: Optional[dict]
session_id: int session_id: UUID
parent_run_id: Optional[UUID] parent_run_id: Optional[UUID]
example_id: Optional[UUID] reference_example_id: Optional[UUID]
run_type: RunTypeEnum run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list) child_runs: List[Run] = Field(default_factory=list)

@ -2,8 +2,9 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import List, Union from typing import List, Tuple, Union
from uuid import uuid4 from unittest.mock import Mock, patch
from uuid import UUID, uuid4
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
@ -16,7 +17,8 @@ from langchain.callbacks.tracers.base import (
TracerException, TracerException,
TracerSession, TracerSession,
) )
from langchain.callbacks.tracers.schemas import TracerSessionCreate from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
from langchain.schema import LLMResult from langchain.schema import LLMResult
TEST_SESSION_ID = 2023 TEST_SESSION_ID = 2023
@ -27,7 +29,7 @@ def load_session(session_name: str) -> TracerSession:
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow()) return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
def _persist_session(session: TracerSessionCreate) -> TracerSession: def _persist_session(session: TracerSessionBase) -> TracerSession:
"""Persist a tracing session.""" """Persist a tracing session."""
return TracerSession(id=TEST_SESSION_ID, **session.dict()) return TracerSession(id=TEST_SESSION_ID, **session.dict())
@ -49,7 +51,7 @@ class FakeTracer(BaseTracer):
"""Persist a run.""" """Persist a run."""
self.runs.append(run) self.runs.append(run)
def _persist_session(self, session: TracerSessionCreate) -> TracerSession: def _persist_session(self, session: TracerSessionBase) -> TracerSession:
"""Persist a tracing session.""" """Persist a tracing session."""
return _persist_session(session) return _persist_session(session)
@ -473,3 +475,125 @@ def test_tracer_nested_runs_on_error() -> None:
) )
assert tracer.runs == [compare_run] * 3 assert tracer.runs == [compare_run] * 3
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
    """Provide a V2 tracer built with a fake tenant ID, endpoint and API key."""
    fake_env = {
        "LANGCHAIN_TENANT_ID": "test-tenant-id",
        "LANGCHAIN_ENDPOINT": "http://test-endpoint.com",
        "LANGCHAIN_API_KEY": "foo",
    }
    for variable, value in fake_env.items():
        monkeypatch.setenv(variable, value)
    return LangChainTracerV2()
# A sample TracerSessionV2 object built from the fixed test IDs
@pytest.fixture
def sample_tracer_session_v2() -> TracerSessionV2:
    """Return a V2 session using the module-level session and tenant UUIDs."""
    session_fields = {
        "id": _SESSION_ID,
        "name": "Sample session",
        "tenant_id": _TENANT_ID,
    }
    return TracerSessionV2(**session_fields)
# Sample LLMRun, ChainRun and ToolRun objects
@pytest.fixture
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
    """Build one run of each kind; the chain run nests the LLM run."""

    def shared_fields(uuid: str, name: str) -> dict:
        # Built anew on every call so the runs never share container objects.
        return dict(
            uuid=uuid,
            name=name,
            execution_order=1,
            child_execution_order=1,
            session_id=1,
            serialized={},
            extra={},
        )

    llm = LLMRun(
        prompts=[],
        response=LLMResult(generations=[[]]),
        **shared_fields("57a08cc4-73d2-4236-8370-549099d07fad", "llm_run"),
    )
    chain = ChainRun(
        inputs={},
        outputs=None,
        child_llm_runs=[llm],
        child_chain_runs=[],
        child_tool_runs=[],
        **shared_fields("57a08cc4-73d2-4236-8371-549099d07fad", "chain_run"),
    )
    tool = ToolRun(
        tool_input="test",
        action="{}",
        child_llm_runs=[],
        child_chain_runs=[],
        child_tool_runs=[],
        **shared_fields("57a08cc4-73d2-4236-8372-549099d07fad", "tool_run"),
    )
    return llm, chain, tool
# _get_default_query_params should expose the tenant ID taken from the environment
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
    """The default query params contain only the configured tenant ID."""
    assert lang_chain_tracer_v2._get_default_query_params() == {
        "tenant_id": "test-tenant-id"
    }
# load_session should query the sessions route and parse the first result
@patch("langchain.callbacks.tracers.langchain.requests.get")
def test_load_session(
    mock_requests_get: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
) -> None:
    """Test that load_session method returns a TracerSessionV2 object."""
    # The mocked endpoint answers with a one-element list of session dicts.
    response_payload = [sample_tracer_session_v2.dict()]
    mock_requests_get.return_value.json.return_value = response_payload

    loaded = lang_chain_tracer_v2.load_session("test-session-name")

    expected_headers = {"Content-Type": "application/json", "x-api-key": "foo"}
    expected_params = {"tenant_id": "test-tenant-id", "name": "test-session-name"}
    mock_requests_get.assert_called_with(
        "http://test-endpoint.com/sessions",
        headers=expected_headers,
        params=expected_params,
    )
    assert loaded == sample_tracer_session_v2
def test_convert_run(
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
    sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
    """Each kind of V1 run converts to a ``Run``."""
    lang_chain_tracer_v2.session = sample_tracer_session_v2
    # Convert everything first, then check the types, in LLM/chain/tool order.
    converted = [lang_chain_tracer_v2._convert_run(run) for run in sample_runs]
    for converted_run in converted:
        assert isinstance(converted_run, Run)
@patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_run(
    mock_requests_post: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
    sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
    """Persisting each sample run issues exactly one POST per run."""
    mock_requests_post.return_value.raise_for_status.return_value = None
    lang_chain_tracer_v2.session = sample_tracer_session_v2
    for run in sample_runs:
        lang_chain_tracer_v2._persist_run(run)
    assert mock_requests_post.call_count == 3

@ -70,7 +70,7 @@ def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
assert file_contents is None assert file_contents is None
file_contents = Path(file_path).read_text() file_contents = Path(file_path).read_text()
mocked_responses.get( mocked_responses.get( # type: ignore
urljoin(URL_BASE.format(ref=ref), path), urljoin(URL_BASE.format(ref=ref), path),
body=body, body=body,
status=200, status=200,
@ -86,7 +86,9 @@ def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
path = "chains/path/chain.json" path = "chains/path/chain.json"
loader = Mock() loader = Mock()
mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500) mocked_responses.get( # type: ignore
urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
)
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")): with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"}) try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})

Loading…
Cancel
Save