diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index f6729885..5fa2b3dd 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -6,7 +6,7 @@ import os import warnings from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast from uuid import UUID, uuid4 from langchain.callbacks.base import ( @@ -21,6 +21,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.tracers.base import TracerSession from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2 +from langchain.callbacks.tracers.schemas import TracerSessionV2 from langchain.schema import AgentAction, AgentFinish, LLMResult Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] @@ -28,7 +29,7 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None ) -tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( +tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 "tracing_callback", default=None ) @@ -48,7 +49,7 @@ def tracing_enabled( ) -> Generator[TracerSession, None, None]: """Get Tracer in a context manager.""" cb = LangChainTracer() - session = cb.load_session(session_name) + session = cast(TracerSession, cb.load_session(session_name)) tracing_callback_var.set(cb) yield session tracing_callback_var.set(None) @@ -57,7 +58,7 @@ def tracing_enabled( @contextmanager def tracing_v2_enabled( session_name: str = "default", -) -> Generator[TracerSession, None, None]: +) -> Generator[TracerSessionV2, None, None]: """Get the experimental tracer handler in a context manager.""" # Issue a warning that this is experimental warnings.warn( diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index a7d3b322..6d036d32 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -12,7 +12,9 @@ from langchain.callbacks.tracers.schemas import ( LLMRun, ToolRun, TracerSession, + TracerSessionBase, TracerSessionCreate, + TracerSessionV2, ) from langchain.schema import LLMResult @@ -27,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} - self.session: Optional[TracerSession] = None + self.session: Optional[Union[TracerSessionV2, TracerSession]] = None @staticmethod def _add_child_run( @@ -49,22 +51,31 @@ class BaseTracer(BaseCallbackHandler, ABC): """Persist a run.""" @abstractmethod - def _persist_session(self, session: TracerSessionCreate) -> TracerSession: + def _persist_session( + self, session: TracerSessionBase + ) -> Union[TracerSession, TracerSessionV2]: """Persist a tracing session.""" - def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession: + def _get_session_create( + self, name: Optional[str] = None, **kwargs: Any + ) -> TracerSessionBase: + return TracerSessionCreate(name=name, extra=kwargs) + + def new_session( + self, name: Optional[str] = None, **kwargs: Any + ) -> Union[TracerSession, TracerSessionV2]: """NOT thread safe, do not call this method from multiple threads.""" - session_create = TracerSessionCreate(name=name, extra=kwargs) + session_create = self._get_session_create(name=name, **kwargs) session = self._persist_session(session_create) self.session = session return session @abstractmethod - def load_session(self, session_name: str) -> TracerSession: + def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]: """Load a tracing session and set it as the Tracer's session.""" @abstractmethod - def load_default_session(self) -> TracerSession: + def load_default_session(self) -> Union[TracerSession, TracerSessionV2]: """Load the default tracing session and set it as the Tracer's session.""" def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 3d9a6b0d..767d1538 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import ( Run, ToolRun, TracerSession, - TracerSessionCreate, + TracerSessionBase, + TracerSessionV2, + TracerSessionV2Create, ) +def _get_headers() -> Dict[str, Any]: + """Get the headers for the LangChain API.""" + headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + return headers + + +def _get_endpoint() -> str: + return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + + class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - def __init__(self, session_name: str = "default", **kwargs: Any) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the LangChain tracer.""" super().__init__(**kwargs) - self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") - self._headers: Dict[str, Any] = {"Content-Type": "application/json"} - if os.getenv("LANGCHAIN_API_KEY"): - self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") - self.session = self.load_session(session_name) + self._endpoint = _get_endpoint() + self._headers = _get_headers() def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" @@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer): except Exception as e: logging.warning(f"Failed to persist run: {e}") - def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession: + def _persist_session( + self, session_create: TracerSessionBase + ) -> Union[TracerSession, TracerSessionV2]: """Persist a session.""" try: r = requests.post( @@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer): self.session = tracer_session return tracer_session - def load_session(self, session_name: str) -> TracerSession: + def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]: """Load a session with the given name from the tracer.""" return self._load_session(session_name) - def load_default_session(self) -> TracerSession: + def load_default_session(self) -> Union[TracerSession, TracerSessionV2]: """Load the default tracing session and set it as the Tracer's session.""" return self._load_session("default") +def _get_tenant_id() -> Optional[str]: + """Get the tenant ID for the LangChain API.""" + tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID") + if tenant_id: + return tenant_id + endpoint = _get_endpoint() + headers = _get_headers() + response = requests.get(endpoint + "/tenants", headers=headers) + response.raise_for_status() + tenants: List[Dict[str, Any]] = response.json() + if not tenants: + raise ValueError(f"No tenants found for URL {endpoint}") + return tenants[0]["id"] + + class LangChainTracerV2(LangChainTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - @staticmethod - def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run: - """Convert a run to a Run.""" + def __init__(self, **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self._endpoint = _get_endpoint() + self._headers = _get_headers() + self.tenant_id = _get_tenant_id() + + def _get_session_create( + self, name: Optional[str] = None, **kwargs: Any + ) -> TracerSessionBase: + return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id) + + def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2: + """Persist a session.""" + try: + r = requests.post( + f"{self._endpoint}/sessions", + data=session_create.json(), + headers=self._headers, + ) + session = TracerSessionV2(id=r.json()["id"], **session_create.dict()) + except Exception as e: + logging.warning(f"Failed to create session, using default session: {e}") + session = self.load_session("default") + return session + + def _get_default_query_params(self) -> Dict[str, Any]: + """Get the query params for the LangChain API.""" + return {"tenant_id": self.tenant_id} + + def load_session(self, session_name: str) -> TracerSessionV2: + """Load a session with the given name from the tracer.""" + try: + url = f"{self._endpoint}/sessions" + params = {"tenant_id": self.tenant_id} + if session_name: + params["name"] = session_name + r = requests.get(url, headers=self._headers, params=params) + tracer_session = TracerSessionV2(**r.json()[0]) + except Exception as e: + session_type = "default" if not session_name else session_name + logging.warning( + f"Failed to load {session_type} session, using empty session: {e}" + ) + tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id) + + self.session = tracer_session + return tracer_session + def load_default_session(self) -> TracerSessionV2: + """Load the default tracing session and set it as the Tracer's session.""" + return self.load_session("default") + + def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run: + """Convert a run to a Run.""" + session = self.session or self.load_default_session() inputs: Dict[str, Any] = {} outputs: Optional[Dict[str, Any]] = None child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] @@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer): return Run( id=run.uuid, - name=run.serialized.get("name"), + name=run.serialized.get("name", f"{run_type}-{run.uuid}"), start_time=run.start_time, end_time=run.end_time, - extra=run.extra, + extra=run.extra or {}, error=run.error, execution_order=run.execution_order, serialized=run.serialized, inputs=inputs, outputs=outputs, - session_id=run.session_id, + session_id=session.id, run_type=run_type, parent_run_id=run.parent_uuid, - child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs], + child_runs=[self._convert_run(child) for child in child_runs], ) def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" run_create = self._convert_run(run) - try: - requests.post( + result = requests.post( f"{self._endpoint}/runs", data=run_create.json(), headers=self._headers, ) + result.raise_for_status() except Exception as e: logging.warning(f"Failed to persist run: {e}") diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index e863bd04..016ef93d 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -31,6 +31,24 @@ class TracerSession(TracerSessionBase): id: int +class TracerSessionV2Base(TracerSessionBase): + """A creation class for TracerSessionV2.""" + + tenant_id: UUID + + +class TracerSessionV2Create(TracerSessionBase): + """A creation class for TracerSessionV2.""" + + pass + + +class TracerSessionV2(TracerSessionV2Base): + """TracerSession schema for the V2 API.""" + + id: UUID + + class BaseRun(BaseModel): """Base class for Run.""" @@ -93,9 +111,9 @@ class Run(BaseModel): serialized: dict inputs: dict outputs: Optional[dict] - session_id: int + session_id: UUID parent_run_id: Optional[UUID] - example_id: Optional[UUID] + reference_example_id: Optional[UUID] run_type: RunTypeEnum child_runs: List[Run] = Field(default_factory=list) diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 66f0c387..49d83f08 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -2,8 +2,9 @@ from __future__ import annotations from datetime import datetime -from typing import List, Union -from uuid import uuid4 +from typing import List, Tuple, Union +from unittest.mock import Mock, patch +from uuid import UUID, uuid4 import pytest from freezegun import freeze_time @@ -16,7 +17,8 @@ from langchain.callbacks.tracers.base import ( TracerException, TracerSession, ) -from langchain.callbacks.tracers.schemas import TracerSessionCreate +from langchain.callbacks.tracers.langchain import LangChainTracerV2 +from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2 from langchain.schema import LLMResult TEST_SESSION_ID = 2023 @@ -27,7 +29,7 @@ def load_session(session_name: str) -> TracerSession: return TracerSession(id=1, name=session_name, start_time=datetime.utcnow()) -def _persist_session(session: TracerSessionCreate) -> TracerSession: +def _persist_session(session: TracerSessionBase) -> TracerSession: """Persist a tracing session.""" return TracerSession(id=TEST_SESSION_ID, **session.dict()) @@ -49,7 +51,7 @@ class FakeTracer(BaseTracer): """Persist a run.""" self.runs.append(run) - def _persist_session(self, session: TracerSessionCreate) -> TracerSession: + def _persist_session(self, session: TracerSessionBase) -> TracerSession: """Persist a tracing session.""" return _persist_session(session) @@ -473,3 +475,125 @@ def test_tracer_nested_runs_on_error() -> None: ) assert tracer.runs == [compare_run] * 3 + + +_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a") +_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad") + + +@pytest.fixture +def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2: + monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") + monkeypatch.setenv("LANGCHAIN_API_KEY", "foo") + tracer = LangChainTracerV2() + return tracer + + +# Mock a sample TracerSessionV2 object +@pytest.fixture +def sample_tracer_session_v2() -> TracerSessionV2: + return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID) + + +# Mock a sample LLMRun, ChainRun, and ToolRun objects +@pytest.fixture +def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]: + llm_run = LLMRun( + uuid="57a08cc4-73d2-4236-8370-549099d07fad", + name="llm_run", + execution_order=1, + child_execution_order=1, + session_id=1, + prompts=[], + response=LLMResult(generations=[[]]), + serialized={}, + extra={}, + ) + chain_run = ChainRun( + uuid="57a08cc4-73d2-4236-8371-549099d07fad", + name="chain_run", + execution_order=1, + child_execution_order=1, + session_id=1, + serialized={}, + inputs={}, + outputs=None, + child_llm_runs=[llm_run], + child_chain_runs=[], + child_tool_runs=[], + extra={}, + ) + tool_run = ToolRun( + uuid="57a08cc4-73d2-4236-8372-549099d07fad", + name="tool_run", + execution_order=1, + child_execution_order=1, + session_id=1, + tool_input="test", + action="{}", + serialized={}, + child_llm_runs=[], + child_chain_runs=[], + child_tool_runs=[], + extra={}, + ) + return llm_run, chain_run, tool_run + + +# Test _get_default_query_params method +def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None: + expected = {"tenant_id": "test-tenant-id"} + result = lang_chain_tracer_v2._get_default_query_params() + assert result == expected + + +# Test load_session method +@patch("langchain.callbacks.tracers.langchain.requests.get") +def test_load_session( + mock_requests_get: Mock, + lang_chain_tracer_v2: LangChainTracerV2, + sample_tracer_session_v2: TracerSessionV2, +) -> None: + """Test that load_session method returns a TracerSessionV2 object.""" + mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()] + result = lang_chain_tracer_v2.load_session("test-session-name") + mock_requests_get.assert_called_with( + "http://test-endpoint.com/sessions", + headers={"Content-Type": "application/json", "x-api-key": "foo"}, + params={"tenant_id": "test-tenant-id", "name": "test-session-name"}, + ) + assert result == sample_tracer_session_v2 + + +def test_convert_run( + lang_chain_tracer_v2: LangChainTracerV2, + sample_tracer_session_v2: TracerSessionV2, + sample_runs: Tuple[LLMRun, ChainRun, ToolRun], +) -> None: + llm_run, chain_run, tool_run = sample_runs + lang_chain_tracer_v2.session = sample_tracer_session_v2 + converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run) + converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run) + converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run) + + assert isinstance(converted_llm_run, Run) + assert isinstance(converted_chain_run, Run) + assert isinstance(converted_tool_run, Run) + + +@patch("langchain.callbacks.tracers.langchain.requests.post") +def test_persist_run( + mock_requests_post: Mock, + lang_chain_tracer_v2: LangChainTracerV2, + sample_tracer_session_v2: TracerSessionV2, + sample_runs: Tuple[LLMRun, ChainRun, ToolRun], +) -> None: + mock_requests_post.return_value.raise_for_status.return_value = None + lang_chain_tracer_v2.session = sample_tracer_session_v2 + llm_run, chain_run, tool_run = sample_runs + lang_chain_tracer_v2._persist_run(llm_run) + lang_chain_tracer_v2._persist_run(chain_run) + lang_chain_tracer_v2._persist_run(tool_run) + + assert mock_requests_post.call_count == 3 diff --git a/tests/unit_tests/utilities/test_loading.py b/tests/unit_tests/utilities/test_loading.py index 380297f7..468cc659 100644 --- a/tests/unit_tests/utilities/test_loading.py +++ b/tests/unit_tests/utilities/test_loading.py @@ -70,7 +70,7 @@ def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None: assert file_contents is None file_contents = Path(file_path).read_text() - mocked_responses.get( + mocked_responses.get( # type: ignore urljoin(URL_BASE.format(ref=ref), path), body=body, status=200, @@ -86,7 +86,9 @@ def test_failed_request(mocked_responses: responses.RequestsMock) -> None: path = "chains/path/chain.json" loader = Mock() - mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500) + mocked_responses.get( # type: ignore + urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500 + ) with pytest.raises(ValueError, match=re.compile("Could not find file at .*")): try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})