[Breaking] Refactor Base Tracer (#4549)

### Refactor the BaseTracer
- Remove the 'session' abstraction from the BaseTracer
- Rename 'RunV2' object(s) to 'Run' objects (and rename the previous
Run objects to RunV1 objects)
- Ditto for sessions: TracerSession*V2 -> TracerSession*
- Remove the now-deprecated conversion from v1 run objects to v2 run objects
in LangChainTracerV2
- Add a conversion from v2 run objects to v1 run objects in the V1 tracer
This commit is contained in:
Zander Chase 2023-05-13 17:23:56 +00:00 committed by GitHub
parent 1e322ffc1c
commit 928cdd57a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 2735 additions and 2170 deletions

View File

@ -20,9 +20,9 @@ from langchain.callbacks.base import (
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.schema import (
AgentAction,
AgentFinish,
@ -37,11 +37,13 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
tracing_callback_var: ContextVar[
Optional[LangChainTracerV1]
] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
tracing_v2_callback_var: ContextVar[
Optional[LangChainTracerV2]
Optional[LangChainTracer]
] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
@ -59,10 +61,10 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
) -> Generator[TracerSessionV1, None, None]:
"""Get Tracer in a context manager."""
cb = LangChainTracer()
session = cast(TracerSession, cb.load_session(session_name))
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
@ -70,9 +72,12 @@ def tracing_enabled(
@contextmanager
def tracing_v2_enabled(
session_name: str = "default",
session_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
) -> Generator[TracerSessionV2, None, None]:
tenant_id: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
) -> Generator[TracerSession, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
warnings.warn(
@ -81,11 +86,16 @@ def tracing_v2_enabled(
)
if isinstance(example_id, str):
example_id = UUID(example_id)
cb = LangChainTracerV2(example_id=example_id)
session = cast(TracerSessionV2, cb.new_session(session_name))
tracing_callback_var.set(cb)
cb = LangChainTracer(
tenant_id=tenant_id,
session_name=session_name,
example_id=example_id,
session_extra=session_extra,
)
session = cb.ensure_session()
tracing_v2_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
tracing_v2_callback_var.set(None)
def _handle_event(
@ -829,32 +839,35 @@ def _configure(
tracer_session = os.environ.get("LANGCHAIN_SESSION")
if tracer_session is None:
tracer_session = "default"
if verbose or tracing_enabled_ or open_ai is not None:
if verbose or tracing_enabled_ or tracing_v2_enabled_ or open_ai is not None:
if verbose and not any(
isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(StdOutCallbackHandler(), False)
if tracing_enabled_ and not any(
isinstance(handler, LangChainTracer)
isinstance(handler, LangChainTracerV1)
for handler in callback_manager.handlers
):
if tracer:
callback_manager.add_handler(tracer, True)
else:
handler = LangChainTracer()
handler = LangChainTracerV1()
handler.load_session(tracer_session)
callback_manager.add_handler(handler, True)
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracerV2)
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer_v2:
callback_manager.add_handler(tracer_v2, True)
else:
handler = LangChainTracerV2()
handler.load_session(tracer_session)
callback_manager.add_handler(handler, True)
try:
handler = LangChainTracer(session_name=tracer_session)
handler.ensure_session()
callback_manager.add_handler(handler, True)
except Exception as e:
logger.debug("Unable to load requested LangChainTracer", e)
if open_ai is not None and not any(
isinstance(handler, OpenAICallbackHandler)
for handler in callback_manager.handlers

View File

@ -1,5 +1,6 @@
"""Tracers that record execution of LangChain runs."""
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
__all__ = ["LangChainTracer"]
__all__ = ["LangChainTracer", "LangChainTracerV1"]

View File

@ -7,15 +7,7 @@ from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
ToolRun,
TracerSession,
TracerSessionBase,
TracerSessionCreate,
TracerSessionV2,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.schema import LLMResult
@ -28,89 +20,45 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[Union[TracerSession, TracerSessionV2]] = None
self.run_map: Dict[str, Run] = {}
@staticmethod
def _add_child_run(
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
parent_run: Run,
child_run: Run,
) -> None:
"""Add child run to a chain run or tool run."""
if isinstance(child_run, LLMRun):
parent_run.child_llm_runs.append(child_run)
elif isinstance(child_run, ChainRun):
parent_run.child_chain_runs.append(child_run)
elif isinstance(child_run, ToolRun):
parent_run.child_tool_runs.append(child_run)
else:
raise TracerException(f"Invalid run type: {type(child_run)}")
parent_run.child_runs.append(child_run)
@abstractmethod
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
@abstractmethod
def _persist_session(
self, session: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a tracing session."""
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionCreate(name=name, extra=kwargs)
def new_session(
self, name: Optional[str] = None, **kwargs: Any
) -> Union[TracerSession, TracerSessionV2]:
"""NOT thread safe, do not call this method from multiple threads."""
session_create = self._get_session_create(name=name, **kwargs)
session = self._persist_session(session_create)
self.session = session
return session
@abstractmethod
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a tracing session and set it as the Tracer's session."""
@abstractmethod
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session."""
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
def _start_trace(self, run: Run) -> None:
"""Start a trace for a run."""
if run.parent_uuid:
parent_run = self.run_map[run.parent_uuid]
if run.parent_run_id:
parent_run = self.run_map[str(run.parent_run_id)]
if parent_run:
if isinstance(parent_run, LLMRun):
raise TracerException(
"Cannot add child run to an LLM run. "
"LLM runs are not allowed to have children."
)
self._add_child_run(parent_run, run)
else:
raise TracerException(
f"Parent run with UUID {run.parent_uuid} not found."
f"Parent run with UUID {run.parent_run_id} not found."
)
self.run_map[str(run.id)] = run
self.run_map[run.uuid] = run
def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
def _end_trace(self, run: Run) -> None:
"""End a trace for a run."""
if not run.parent_uuid:
if not run.parent_run_id:
self._persist_run(run)
else:
parent_run = self.run_map.get(run.parent_uuid)
parent_run = self.run_map.get(str(run.parent_run_id))
if parent_run is None:
raise TracerException(
f"Parent run with UUID {run.parent_uuid} not found."
f"Parent run with UUID {run.parent_run_id} not found."
)
if isinstance(parent_run, LLMRun):
raise TracerException("LLM Runs are not allowed to have children. ")
if run.child_execution_order > parent_run.child_execution_order:
parent_run.child_execution_order = run.child_execution_order
self.run_map.pop(run.uuid)
self.run_map.pop(str(run.id))
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
"""Get the execution order for a run."""
@ -121,9 +69,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
if parent_run is None:
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
if isinstance(parent_run, LLMRun):
raise TracerException("LLM Runs are not allowed to have children. ")
return parent_run.child_execution_order + 1
def on_llm_start(
@ -136,25 +81,22 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
llm_run = LLMRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
llm_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
prompts=prompts,
inputs={"prompts": prompts},
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
session_id=self.session.id,
run_type=RunTypeEnum.llm,
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for an LLM run."""
@ -163,11 +105,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
llm_run.response = response
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
raise TracerException("No LLM Run found to be traced")
llm_run.outputs = response.dict()
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)
self._on_llm_end(llm_run)
def on_llm_error(
self,
@ -182,12 +125,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
raise TracerException("No LLM Run found to be traced")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)
self._on_chain_error(llm_run)
def on_chain_start(
self,
@ -199,16 +142,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> None:
"""Start a trace for a chain run."""
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
chain_run = ChainRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
chain_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs=inputs,
extra=kwargs,
@ -216,23 +155,25 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
session_id=self.session.id,
run_type=RunTypeEnum.chain,
)
self._start_trace(chain_run)
self._on_chain_start(chain_run)
def on_chain_end(
self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
) -> None:
"""End a trace for a chain run."""
run_id_ = str(run_id)
chain_run = self.run_map.get(run_id_)
if chain_run is None or not isinstance(chain_run, ChainRun):
raise TracerException("No ChainRun found to be traced")
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
raise TracerException("No chain Run found to be traced")
chain_run.outputs = outputs
chain_run.end_time = datetime.utcnow()
self._end_trace(chain_run)
self._on_chain_end(chain_run)
def on_chain_error(
self,
@ -242,15 +183,16 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> None:
"""Handle an error for a chain run."""
run_id_ = str(run_id)
chain_run = self.run_map.get(run_id_)
if chain_run is None or not isinstance(chain_run, ChainRun):
raise TracerException("No ChainRun found to be traced")
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
raise TracerException("No chain Run found to be traced")
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
self._end_trace(chain_run)
self._on_chain_error(chain_run)
def on_tool_start(
self,
@ -262,40 +204,36 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> None:
"""Start a trace for a tool run."""
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
tool_run = ToolRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
tool_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
# TODO: this is duplicate info as above, not needed.
action=str(serialized),
tool_input=input_str,
inputs={"input": input_str},
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
session_id=self.session.id,
run_type=RunTypeEnum.tool,
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run."""
run_id_ = str(run_id)
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
raise TracerException("No tool Run found to be traced")
tool_run = self.run_map.get(run_id_)
if tool_run is None or not isinstance(tool_run, ToolRun):
raise TracerException("No ToolRun found to be traced")
tool_run.output = output
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.utcnow()
self._end_trace(tool_run)
self._on_tool_end(tool_run)
def on_tool_error(
self,
@ -305,15 +243,16 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> None:
"""Handle an error for a tool run."""
run_id_ = str(run_id)
tool_run = self.run_map.get(run_id_)
if tool_run is None or not isinstance(tool_run, ToolRun):
raise TracerException("No ToolRun found to be traced")
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
raise TracerException("No tool Run found to be traced")
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
self._end_trace(tool_run)
self._on_tool_error(tool_run)
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
@ -322,3 +261,33 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __copy__(self) -> BaseTracer:
"""Copy the tracer."""
return self
def _on_llm_start(self, run: Run) -> None:
    """Hook called with the LLM run once its trace has been started.

    Does nothing by default; subclasses may override it.
    """
def _on_llm_end(self, run: Run) -> None:
    """Hook called with the LLM run once its trace has been ended.

    Does nothing by default; subclasses may override it.
    """
def _on_llm_error(self, run: Run) -> None:
    """Hook for an LLM run that finished with an error.

    Does nothing by default; subclasses may override it.

    NOTE(review): the visible ``on_llm_error`` handler dispatches to
    ``_on_chain_error`` rather than to this hook -- confirm whether that
    is intended.
    """
def _on_chain_start(self, run: Run) -> None:
    """Hook called with the chain run once its trace has been started.

    Does nothing by default; subclasses may override it.
    """
def _on_chain_end(self, run: Run) -> None:
    """Hook called with the chain run once its trace has been ended.

    Does nothing by default; subclasses may override it.
    """
def _on_chain_error(self, run: Run) -> None:
    """Hook called with a run whose trace ended with an error.

    Does nothing by default; subclasses may override it.
    """
def _on_tool_start(self, run: Run) -> None:
    """Hook called with the tool run once its trace has been started.

    Does nothing by default; subclasses may override it.
    """
def _on_tool_end(self, run: Run) -> None:
    """Hook called with the tool run once its trace has been ended.

    Does nothing by default; subclasses may override it.
    """
def _on_tool_error(self, run: Run) -> None:
    """Hook called with the tool run once its trace ended with an error.

    Does nothing by default; subclasses may override it.
    """
def _on_chat_model_start(self, run: Run) -> None:
    """Hook for a chat model run whose trace has been started.

    Does nothing by default; subclasses may override it.
    """

View File

@ -4,27 +4,24 @@ from __future__ import annotations
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4
from typing import Any, Dict, List, Optional
from uuid import UUID
import requests
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
Run,
RunCreate,
ToolRun,
RunTypeEnum,
TracerSession,
TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
TracerSessionCreate,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text
def _get_headers() -> Dict[str, Any]:
def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
@ -32,168 +29,47 @@ def _get_headers() -> Dict[str, Any]:
return headers
def _get_endpoint() -> str:
def get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
def _get_tenant_id(
    tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
) -> str:
    """Resolve the tenant ID to use with the LangChain API.

    Lookup order:
      1. the ``tenant_id`` argument, when truthy;
      2. the ``LANGCHAIN_TENANT_ID`` environment variable, when truthy;
      3. the ``id`` of the first entry returned by ``GET <endpoint>/tenants``.

    ``endpoint`` and ``headers`` fall back to ``get_endpoint()`` and
    ``get_headers()`` when they are falsy.

    Raises:
        ValueError: If the ``/tenants`` route returns an empty result.
    """
    resolved = tenant_id if tenant_id else os.getenv("LANGCHAIN_TENANT_ID")
    if resolved:
        return resolved

    # Nothing configured locally: ask the API which tenants are visible.
    base_url = endpoint if endpoint else get_endpoint()
    request_headers = headers if headers else get_headers()
    response = requests.get(f"{base_url}/tenants", headers=request_headers)
    raise_for_status_with_text(response)

    tenants: List[Dict[str, Any]] = response.json()
    if not tenants:
        raise ValueError(f"No tenants found for URL {base_url}")
    first_tenant = tenants[0]
    return first_tenant["id"]
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
tenant_id: Optional[str] = None,
example_id: Optional[UUID] = None,
session_name: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, LLMRun):
endpoint = f"{self._endpoint}/llm-runs"
elif isinstance(run, ChainRun):
endpoint = f"{self._endpoint}/chain-runs"
else:
endpoint = f"{self._endpoint}/tool-runs"
try:
response = requests.post(
endpoint,
data=run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
def _persist_session(
self, session_create: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSession(id=r.json()["id"], **session_create.dict())
except Exception as e:
logging.warning(f"Failed to create session, using default session: {e}")
session = TracerSession(id=1, **session_create.dict())
return session
def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
"""Load a session from the tracer."""
try:
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)
tracer_session = TracerSession(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSession(id=1)
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")
def _get_tenant_id() -> Optional[str]:
"""Get the tenant ID for the LangChain API."""
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
if tenant_id:
return tenant_id
endpoint = _get_endpoint()
headers = _get_headers()
response = requests.get(endpoint + "/tenants", headers=headers)
raise_for_status_with_text(response)
tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint}")
return tenants[0]["id"]
class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
self.tenant_id = _get_tenant_id()
self.session: Optional[TracerSession] = None
self._endpoint = get_endpoint()
self._headers = get_headers()
self.tenant_id = tenant_id
self.example_id = example_id
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
"""Persist a session."""
session: Optional[TracerSessionV2] = None
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
raise_for_status_with_text(r)
creation_args = session_create.dict()
if "id" in creation_args:
del creation_args["id"]
return TracerSessionV2(id=r.json()["id"], **creation_args)
except Exception as e:
if session_create.name is not None:
try:
return self.load_session(session_create.name)
except Exception:
pass
logging.warning(
f"Failed to create session {session_create.name},"
f" using empty session: {e}"
)
session = TracerSessionV2(id=uuid4(), **session_create.dict())
return session
def _get_default_query_params(self) -> Dict[str, Any]:
"""Get the query params for the LangChain API."""
return {"tenant_id": self.tenant_id}
def load_session(self, session_name: str) -> TracerSessionV2:
"""Load a session with the given name from the tracer."""
try:
url = f"{self._endpoint}/sessions"
params = {"tenant_id": self.tenant_id}
if session_name:
params["name"] = session_name
r = requests.get(url, headers=self._headers, params=params)
raise_for_status_with_text(r)
tracer_session = TracerSessionV2(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)
self.session = tracer_session
return tracer_session
def load_default_session(self) -> TracerSessionV2:
"""Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default")
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
self.session_extra = session_extra
def on_chat_model_start(
self,
@ -205,81 +81,56 @@ class LangChainTracerV2(LangChainTracer):
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
llm_run = LLMRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
chat_model_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
prompts=[],
extra={**kwargs, "messages": messages},
inputs={"messages": messages_to_dict(batch) for batch in messages},
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
session_id=self.session.id,
run_type=RunTypeEnum.llm,
)
self._start_trace(llm_run)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
inputs: Dict[str, Any] = {}
outputs: Optional[Dict[str, Any]] = None
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
if isinstance(run, LLMRun):
run_type = "llm"
if run.extra is not None and "messages" in run.extra:
messages: List[List[BaseMessage]] = run.extra.pop("messages")
converted_messages = [messages_to_dict(batch) for batch in messages]
inputs = {"messages": converted_messages}
else:
inputs = {"prompts": run.prompts}
outputs = run.response.dict() if run.response else {}
child_runs = []
elif isinstance(run, ChainRun):
run_type = "chain"
inputs = run.inputs
outputs = run.outputs
child_runs = [
*run.child_llm_runs,
*run.child_chain_runs,
*run.child_tool_runs,
]
else:
run_type = "tool"
inputs = {"input": run.tool_input}
outputs = {"output": run.output} if run.output else {}
child_runs = [
*run.child_llm_runs,
*run.child_chain_runs,
*run.child_tool_runs,
]
return RunCreate(
id=run.uuid,
name=run.serialized.get("name"),
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra or {},
error=run.error,
execution_order=run.execution_order,
serialized=run.serialized,
inputs=inputs,
outputs=outputs,
session_id=session.id,
run_type=run_type,
child_runs=[self._convert_run(child) for child in child_runs],
def ensure_tenant_id(self) -> str:
    """Return the tracer's tenant ID, resolving and storing it if unset.

    When ``self.tenant_id`` is falsy the value is obtained from
    ``_get_tenant_id`` using this tracer's endpoint and headers; the result
    is written back to ``self.tenant_id`` before being returned.
    """
    resolved = self.tenant_id
    if not resolved:
        resolved = _get_tenant_id(self.tenant_id, self._endpoint, self._headers)
    self.tenant_id = resolved
    return resolved
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
def ensure_session(self) -> TracerSession:
    """Return the tracer's session, upserting it on the server if needed.

    If ``self.session`` is already set it is returned unchanged. Otherwise a
    ``TracerSessionCreate`` built from the tracer's session name, session
    extra and tenant ID is POSTed to ``<endpoint>/sessions?upsert=true``, and
    the response body is stored on ``self.session`` as a ``TracerSession``.

    HTTP and validation errors are not caught here; they propagate to the
    caller.
    """
    if self.session is None:
        tenant_id = self.ensure_tenant_id()
        payload = TracerSessionCreate(
            name=self.session_name,
            extra=self.session_extra,
            tenant_id=tenant_id,
        )
        response = requests.post(
            f"{self._endpoint}/sessions?upsert=true",
            data=payload.json(),
            headers=self._headers,
        )
        raise_for_status_with_text(response)
        self.session = TracerSession(**response.json())
    return self.session
def _persist_run_nested(self, run: Run) -> None:
"""Persist a run."""
run_create = self._convert_run(run)
run_create.reference_example_id = self.example_id
session = self.ensure_session()
child_runs = run.child_runs
run_dict = run.dict()
del run_dict["child_runs"]
run_create = RunCreate(**run_dict, session_id=session.id)
try:
response = requests.post(
f"{self._endpoint}/runs",
@ -289,3 +140,12 @@ class LangChainTracerV2(LangChainTracer):
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
for child_run in child_runs:
child_run.parent_run_id = run.id
self._persist_run_nested(child_run)
def _persist_run(self, run: Run) -> None:
    """Persist a finished root run together with its child runs.

    The tracer's ``example_id`` is stamped onto the run as
    ``reference_example_id`` before the run tree is handed to
    ``_persist_run_nested``.
    """
    # TODO: Post first then patch
    example_id = self.example_id
    run.reference_example_id = example_id
    self._persist_run_nested(run)

View File

@ -0,0 +1,171 @@
from __future__ import annotations
import logging
from typing import Any, Optional, Union
import requests
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.langchain import get_endpoint, get_headers
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
Run,
ToolRun,
TracerSession,
TracerSessionV1,
TracerSessionV1Base,
)
from langchain.schema import get_buffer_string
from langchain.utils import raise_for_status_with_text
class LangChainTracerV1(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self.session: Optional[TracerSessionV1] = None
self._endpoint = get_endpoint()
self._headers = get_headers()
def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
session = self.session or self.load_default_session()
if not isinstance(session, TracerSessionV1):
raise ValueError(
"LangChainTracerV1 is not compatible with"
f" session of type {type(session)}"
)
if run.run_type == "llm":
if "prompts" in run.inputs:
prompts = run.inputs["prompts"]
elif "messages" in run.inputs:
prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
error=run.error,
extra=run.extra,
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
)
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),
output=None if run.outputs is None else run.outputs.get("output"),
error=run.error,
extra=run.extra,
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
)
raise ValueError(f"Unknown run type: {run.run_type}")
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
v1_run = self._convert_to_v1_run(run)
else:
v1_run = run
if isinstance(v1_run, LLMRun):
endpoint = f"{self._endpoint}/llm-runs"
elif isinstance(v1_run, ChainRun):
endpoint = f"{self._endpoint}/chain-runs"
else:
endpoint = f"{self._endpoint}/tool-runs"
try:
response = requests.post(
endpoint,
data=v1_run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
def _persist_session(
self, session_create: TracerSessionV1Base
) -> Union[TracerSessionV1, TracerSession]:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
except Exception as e:
logging.warning(f"Failed to create session, using default session: {e}")
session = TracerSessionV1(id=1, **session_create.dict())
return session
def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
    """Fetch a session from the V1 tracing endpoint and make it current.

    Requests ``{endpoint}/sessions`` (filtered by ``name`` when a session
    name is given) and uses the first entry of the JSON reply.  On any
    failure a warning is logged and an empty session with ``id=1`` is used.
    In both cases the result is stored on ``self.session``.

    Args:
        session_name: Name of the session to load; a falsy value loads
            without a name filter.

    Returns:
        The loaded session, or the ``id=1`` fallback session on failure.
    """
    try:
        url = f"{self._endpoint}/sessions"
        # Hand the name to requests as a query parameter so it is
        # percent-encoded; splicing it into the URL by hand breaks for names
        # containing spaces, '&', '#', '=' and similar characters.
        params = {"name": session_name} if session_name else None
        r = requests.get(url, params=params, headers=self._headers)
        tracer_session = TracerSessionV1(**r.json()[0])
    except Exception as e:
        session_type = "default" if not session_name else session_name
        logging.warning(
            f"Failed to load {session_type} session, using empty session: {e}"
        )
        tracer_session = TracerSessionV1(id=1)
    self.session = tracer_session
    return tracer_session
def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
    """Load the named session from the tracer backend.

    Delegates to ``_load_session``, which also stores the result on
    ``self.session``.
    """
    named_session = self._load_session(session_name)
    return named_session
def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
    """Load the session named "default" and make it the tracer's session.

    Delegates to ``_load_session``, which assigns the result to
    ``self.session``.
    """
    default_session = self._load_session("default")
    return default_session

View File

@ -6,47 +6,45 @@ from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, root_validator
from langchain.schema import LLMResult
class TracerSessionBase(BaseModel):
"""Base class for TracerSession."""
class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1."""
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
class TracerSessionCreate(TracerSessionBase):
"""Create class for TracerSession."""
class TracerSessionV1Create(TracerSessionV1Base):
"""Create class for TracerSessionV1."""
pass
class TracerSession(TracerSessionBase):
"""TracerSession schema."""
class TracerSessionV1(TracerSessionV1Base):
"""TracerSessionV1 schema."""
id: int
class TracerSessionV2Base(TracerSessionBase):
"""A creation class for TracerSessionV2."""
class TracerSessionBase(TracerSessionV1Base):
"""A creation class for TracerSession."""
tenant_id: UUID
class TracerSessionV2Create(TracerSessionV2Base):
"""A creation class for TracerSessionV2."""
class TracerSessionCreate(TracerSessionBase):
"""A creation class for TracerSession."""
id: Optional[UUID]
pass
class TracerSessionV2(TracerSessionV2Base):
"""TracerSession schema for the V2 API."""
class TracerSession(TracerSessionBase):
    """TracerSession schema for the V2 API."""

    # Session identifier; a UUID here, whereas TracerSessionV1 uses an int id.
    id: UUID
@ -111,26 +109,32 @@ class RunBase(BaseModel):
extra: dict
error: Optional[str]
execution_order: int
child_execution_order: int
serialized: dict
inputs: dict
outputs: Optional[dict]
session_id: UUID
reference_example_id: Optional[UUID]
run_type: RunTypeEnum
parent_run_id: Optional[UUID]
class RunCreate(RunBase):
"""Schema to create a run in the DB."""
name: Optional[str]
child_runs: List[RunCreate] = Field(default_factory=list)
class Run(RunBase):
"""Run schema when loading from the DB."""
name: str
child_runs: List[Run] = Field(default_factory=list)
@root_validator(pre=True)
def assign_name(cls, values: dict) -> dict:
"""Assign name to the run."""
if "name" not in values:
values["name"] = values["serialized"]["name"]
return values
class RunCreate(RunBase):
name: str
session_id: UUID
ChainRun.update_forward_refs()

View File

@ -27,7 +27,7 @@ from requests import Response
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
@ -308,7 +308,7 @@ class LangChainPlusClient(BaseSettings):
async def _arun_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
langchain_tracer: LangChainTracerV2,
langchain_tracer: LangChainTracer,
) -> Union[LLMResult, ChatResult]:
if isinstance(llm, BaseLLM):
if "prompt" not in inputs:
@ -328,7 +328,7 @@ class LangChainPlusClient(BaseSettings):
@staticmethod
async def _arun_llm_or_chain(
example: Example,
langchain_tracer: LangChainTracerV2,
langchain_tracer: LangChainTracer,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
@ -358,8 +358,8 @@ class LangChainPlusClient(BaseSettings):
@staticmethod
async def _gather_with_concurrency(
n: int,
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracerV2, Dict]]],
*async_funcs: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]],
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]],
*async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
) -> List[Any]:
"""
Run coroutines with a concurrency limit.
@ -376,7 +376,7 @@ class LangChainPlusClient(BaseSettings):
tracer, job_state = await initializer()
async def run_coroutine_with_semaphore(
async_func: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]]
async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
) -> Any:
async with semaphore:
return await async_func(tracer, job_state)
@ -387,7 +387,7 @@ class LangChainPlusClient(BaseSettings):
async def _tracer_initializer(
self, session_name: str
) -> Tuple[LangChainTracerV2, dict]:
) -> Tuple[LangChainTracer, dict]:
"""
Initialize a tracer to share across tasks.
@ -395,11 +395,11 @@ class LangChainPlusClient(BaseSettings):
session_name: The session name for the tracer.
Returns:
A LangChainTracerV2 instance with an active session.
A LangChainTracer instance with an active session.
"""
job_state = {"num_processed": 0}
with tracing_v2_enabled(session_name=session_name) as session:
tracer = LangChainTracerV2()
tracer = LangChainTracer()
tracer.session = session
return tracer, job_state
@ -440,7 +440,7 @@ class LangChainPlusClient(BaseSettings):
results: Dict[str, List[Any]] = {}
async def process_example(
example: Example, tracer: LangChainTracerV2, job_state: dict
example: Example, tracer: LangChainTracer, job_state: dict
) -> None:
"""Process a single example."""
result = await LangChainPlusClient._arun_llm_or_chain(
@ -469,7 +469,7 @@ class LangChainPlusClient(BaseSettings):
def run_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
langchain_tracer: LangChainTracerV2,
langchain_tracer: LangChainTracer,
) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example."""
if isinstance(llm, BaseLLM):
@ -492,7 +492,7 @@ class LangChainPlusClient(BaseSettings):
@staticmethod
def run_llm_or_chain(
example: Example,
langchain_tracer: LangChainTracerV2,
langchain_tracer: LangChainTracer,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
@ -551,7 +551,7 @@ class LangChainPlusClient(BaseSettings):
examples = list(self.list_examples(dataset_id=str(dataset.id)))
results: Dict[str, Any] = {}
with tracing_v2_enabled(session_name=session_name) as session:
tracer = LangChainTracerV2()
tracer = LangChainTracer()
tracer.session = session
for i, example in enumerate(examples):

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,464 @@
"""Test Tracer classes."""
from __future__ import annotations
from datetime import datetime
from typing import List
from uuid import uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import LLMResult
class FakeTracer(BaseTracer):
    """In-memory tracer: keeps every persisted run in ``self.runs``."""

    def __init__(self) -> None:
        """Set up the base tracer and an empty list of recorded runs."""
        super().__init__()
        self.runs: List[Run] = list()

    def _persist_run(self, run: Run) -> None:
        """Record the run locally instead of sending it anywhere."""
        self.runs += [run]
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """One LLM start/end pair must persist exactly one matching run."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=run_id)
    fake_tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    expected_run = Run(
        id=run_id,
        parent_run_id=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """A chat-model run driven through a CallbackManager is recorded as an LLM run."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    callback_manager = CallbackManager(handlers=[fake_tracer])
    llm_run_manager = callback_manager.on_chat_model_start(
        serialized={"name": "chat_model"}, messages=[[]], run_id=run_id
    )
    llm_run_manager.on_llm_end(response=LLMResult(generations=[[]]))

    expected_run = Run(
        id=str(run_id),
        name="chat_model",
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chat_model"},
        inputs={"prompts": [""]},
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Ending an LLM run that was never started must raise TracerException."""
    fake_tracer = FakeTracer()
    unknown_run_id = uuid4()
    empty_result = LLMResult(generations=[[]])

    with pytest.raises(TracerException):
        fake_tracer.on_llm_end(response=empty_result, run_id=unknown_run_id)
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Repeating the same LLM run several times records one run per repetition."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()
    repetitions = 10

    fake_tracer = FakeTracer()
    for _ in range(repetitions):
        fake_tracer.on_llm_start(
            serialized={"name": "llm"}, prompts=[], run_id=run_id
        )
        fake_tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    expected_run = Run(
        id=run_id,
        name="llm",
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    assert fake_tracer.runs == [expected_run] * repetitions
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """One chain start/end pair must persist exactly one matching run."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=run_id)
    fake_tracer.on_chain_end(outputs={}, run_id=run_id)

    expected_run = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        error=None,
        run_type="chain",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """One tool start/end pair must persist exactly one matching run."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=run_id
    )
    fake_tracer.on_tool_end("test", run_id=run_id)

    expected_run = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        inputs=dict(input="test"),
        outputs=dict(output="test"),
        error=None,
        run_type="tool",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """A chain -> (tool -> llm), llm hierarchy is persisted as one nested run tree."""
    fake_tracer = FakeTracer()
    chain_id = uuid4()
    tool_id = uuid4()
    first_llm_id = uuid4()
    second_llm_id = uuid4()
    repetitions = 10

    for _ in range(repetitions):
        fake_tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_id
        )
        fake_tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=first_llm_id,
            parent_run_id=tool_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=first_llm_id
        )
        fake_tracer.on_tool_end("test", run_id=tool_id)
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=second_llm_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=second_llm_id
        )
        fake_tracer.on_chain_end(outputs={}, run_id=chain_id)

    # The clock is frozen by the decorator, so every timestamp is identical.
    frozen_now = datetime.utcnow()

    # Expected tree, built leaves first.
    expected_tool_llm = Run(
        id=str(first_llm_id),
        parent_run_id=str(tool_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=3,
        child_execution_order=3,
        serialized={"name": "llm"},
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        run_type="llm",
    )
    expected_tool = Run(
        id=tool_id,
        parent_run_id=chain_id,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=2,
        child_execution_order=3,
        serialized={"name": "tool"},
        inputs={"input": "test"},
        outputs={"output": "test"},
        error=None,
        run_type="tool",
        child_runs=[expected_tool_llm],
    )
    expected_chain_llm = Run(
        id=str(second_llm_id),
        parent_run_id=str(chain_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=4,
        child_execution_order=4,
        serialized={"name": "llm"},
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        run_type="llm",
    )
    expected_chain = Run(
        id=str(chain_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        run_type="chain",
        child_runs=[expected_tool, expected_chain_llm],
    )
    assert fake_tracer.runs[0] == expected_chain
    assert fake_tracer.runs == [expected_chain] * repetitions
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """An LLM run that errors is persisted with the error repr and no outputs."""
    failure = Exception("test")
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=run_id)
    fake_tracer.on_llm_error(failure, run_id=run_id)

    expected_run = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        inputs={"prompts": []},
        outputs=None,
        error=repr(failure),
        run_type="llm",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """A chain run that errors is persisted with the error repr and no outputs."""
    failure = Exception("test")
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=run_id)
    fake_tracer.on_chain_error(failure, run_id=run_id)

    expected_run = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs=None,
        error=repr(failure),
        run_type="chain",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """A tool run that errors is persisted with the error repr and no outputs."""
    failure = Exception("test")
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=run_id
    )
    fake_tracer.on_tool_error(failure, run_id=run_id)

    expected_run = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        inputs={"input": "test"},
        outputs=None,
        # NOTE(review): "action" looks like a V1 ToolRun field; confirm that
        # Run is meant to accept it here.
        action="{'name': 'tool'}",
        error=repr(failure),
        run_type="tool",
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Errors inside a nested hierarchy are recorded on every failing level."""
    failure = Exception("test")
    fake_tracer = FakeTracer()
    chain_id = uuid4()
    tool_id = uuid4()
    first_llm_id = uuid4()
    second_llm_id = uuid4()
    failing_llm_id = uuid4()
    repetitions = 3

    for _ in range(repetitions):
        fake_tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_id
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=first_llm_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=first_llm_id
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=second_llm_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=second_llm_id
        )
        fake_tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=failing_llm_id,
            parent_run_id=tool_id,
        )
        # The failure propagates outwards: llm, then tool, then chain.
        fake_tracer.on_llm_error(failure, run_id=failing_llm_id)
        fake_tracer.on_tool_error(failure, run_id=tool_id)
        fake_tracer.on_chain_error(failure, run_id=chain_id)

    # The clock is frozen by the decorator, so every timestamp is identical.
    frozen_now = datetime.utcnow()

    expected_first_llm = Run(
        id=str(first_llm_id),
        parent_run_id=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=2,
        child_execution_order=2,
        serialized={"name": "llm"},
        error=None,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]], llm_output=None),
        run_type="llm",
    )
    expected_second_llm = Run(
        id=str(second_llm_id),
        parent_run_id=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=3,
        child_execution_order=3,
        serialized={"name": "llm"},
        error=None,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]], llm_output=None),
        run_type="llm",
    )
    expected_failing_llm = Run(
        id=str(failing_llm_id),
        parent_run_id=str(tool_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=5,
        child_execution_order=5,
        serialized={"name": "llm"},
        error=repr(failure),
        inputs={"prompts": []},
        outputs=None,
        run_type="llm",
    )
    expected_tool = Run(
        id=str(tool_id),
        parent_run_id=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=4,
        child_execution_order=5,
        serialized={"name": "tool"},
        error=repr(failure),
        inputs={"input": "test"},
        outputs=None,
        # NOTE(review): "action" looks like a V1 ToolRun field; confirm that
        # Run is meant to accept it here.
        action="{'name': 'tool'}",
        child_runs=[expected_failing_llm],
        run_type="tool",
    )
    expected_chain = Run(
        id=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={"name": "chain"},
        error=repr(failure),
        inputs={},
        outputs=None,
        run_type="chain",
        child_runs=[expected_first_llm, expected_second_llm, expected_tool],
    )
    assert fake_tracer.runs == [expected_chain] * repetitions

View File

@ -0,0 +1,676 @@
"""Test Tracer classes."""
from __future__ import annotations
from datetime import datetime
from typing import List, Optional, Union
from uuid import uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.langchain_v1 import (
ChainRun,
LangChainTracerV1,
LLMRun,
ToolRun,
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
def load_session(session_name: str) -> TracerSessionV1:
    """Return a stand-in V1 session with the given name and the test session id."""
    created_at = datetime.utcnow()
    return TracerSessionV1(
        id=TEST_SESSION_ID,
        name=session_name,
        start_time=created_at,
    )
def new_session(name: Optional[str] = None) -> TracerSessionV1:
    """Return a fresh stand-in V1 session; a falsy name becomes "default"."""
    session_name = name if name else "default"
    created_at = datetime.utcnow()
    return TracerSessionV1(
        id=TEST_SESSION_ID,
        name=session_name,
        start_time=created_at,
    )
def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
    """Turn session-creation data into a session carrying the test session id."""
    session_fields = session.dict()
    # The test id always wins over any "id" already present in the fields.
    session_fields["id"] = TEST_SESSION_ID
    return TracerSessionV1(**session_fields)
def load_default_session() -> TracerSessionV1:
    """Return a stand-in V1 session named "default" with the test session id."""
    created_at = datetime.utcnow()
    return TracerSessionV1(
        id=TEST_SESSION_ID,
        name="default",
        start_time=created_at,
    )
@pytest.fixture
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
    """Provide a LangChainTracerV1 built with test LANGCHAIN_* environment values."""
    env_overrides = {
        "LANGCHAIN_TENANT_ID": "test-tenant-id",
        "LANGCHAIN_ENDPOINT": "http://test-endpoint.com",
        "LANGCHAIN_API_KEY": "foo",
    }
    for variable, value in env_overrides.items():
        monkeypatch.setenv(variable, value)
    return LangChainTracerV1()
class FakeTracer(BaseTracer):
    """In-memory tracer that stores every persisted run in V1 form."""

    def __init__(self) -> None:
        """Set up the base tracer and an empty list of recorded runs."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
        """Record a run, converting new-style ``Run`` objects to V1 first."""
        if not isinstance(run, Run):
            # Already a V1 run object; store it unchanged.
            self.runs.append(run)
            return

        # A throwaway V1 tracer does the conversion; the temporary environment
        # variables are set around its construction and the conversion call.
        with pytest.MonkeyPatch().context() as env:
            env.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
            env.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
            env.setenv("LANGCHAIN_API_KEY", "foo")
            v1_tracer = LangChainTracerV1()
            v1_tracer.load_default_session = load_default_session  # type: ignore
            converted = v1_tracer._convert_to_v1_run(run)
        self.runs.append(converted)

    def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
        """Delegate to the module-level ``_persist_session`` stub."""
        persisted = _persist_session(session)
        return persisted

    def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
        """Delegate to the module-level ``new_session`` stub."""
        created = new_session(name)
        return created

    def load_session(self, session_name: str) -> TracerSessionV1:
        """Delegate to the module-level ``load_session`` stub."""
        loaded = load_session(session_name)
        return loaded

    def load_default_session(self) -> TracerSessionV1:
        """Delegate to the module-level ``load_default_session`` stub."""
        default = load_default_session()
        return default
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """One LLM start/end pair must yield exactly one matching V1 LLMRun."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    fake_tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=run_id)
    fake_tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """A chat-model run driven through a CallbackManager yields a V1 LLMRun."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    callback_manager = CallbackManager(handlers=[fake_tracer])
    llm_run_manager = callback_manager.on_chat_model_start(
        serialized={"name": "chat_model"}, messages=[[]], run_id=run_id
    )
    llm_run_manager.on_llm_end(response=LLMResult(generations=[[]]))

    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chat_model"},
        prompts=[""],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Ending an LLM run that was never started must raise TracerException."""
    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    unknown_run_id = uuid4()
    empty_result = LLMResult(generations=[[]])

    with pytest.raises(TracerException):
        fake_tracer.on_llm_end(response=empty_result, run_id=unknown_run_id)
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Repeating the same LLM run several times records one LLMRun per repetition."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()
    repetitions = 10

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    for _ in range(repetitions):
        fake_tracer.on_llm_start(
            serialized={"name": "llm"}, prompts=[], run_id=run_id
        )
        fake_tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert fake_tracer.runs == [expected_run] * repetitions
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """One chain start/end pair must yield exactly one matching V1 ChainRun."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    fake_tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=run_id)
    fake_tracer.on_chain_end(outputs={}, run_id=run_id)

    expected_run = ChainRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """One tool start/end pair must yield exactly one matching V1 ToolRun."""
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    fake_tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=run_id
    )
    fake_tracer.on_tool_end("test", run_id=run_id)

    expected_run = ToolRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        tool_input="test",
        output="test",
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """A chain -> (tool -> llm), llm hierarchy converts to one nested V1 ChainRun."""
    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    chain_id = uuid4()
    tool_id = uuid4()
    first_llm_id = uuid4()
    second_llm_id = uuid4()
    repetitions = 10

    for _ in range(repetitions):
        fake_tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_id
        )
        fake_tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=first_llm_id,
            parent_run_id=tool_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=first_llm_id
        )
        fake_tracer.on_tool_end("test", run_id=tool_id)
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=second_llm_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=second_llm_id
        )
        fake_tracer.on_chain_end(outputs={}, run_id=chain_id)

    # The clock is frozen by the decorator, so every timestamp is identical.
    frozen_now = datetime.utcnow()

    # Expected tree, built leaves first.
    expected_tool_llm = LLMRun(
        uuid=str(first_llm_id),
        parent_uuid=str(tool_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=3,
        child_execution_order=3,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
    )
    expected_tool = ToolRun(
        uuid=str(tool_id),
        parent_uuid=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=2,
        child_execution_order=3,
        serialized={"name": "tool"},
        tool_input="test",
        output="test",
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=None,
        child_chain_runs=[],
        child_tool_runs=[],
        child_llm_runs=[expected_tool_llm],
    )
    expected_chain_llm = LLMRun(
        uuid=str(second_llm_id),
        parent_uuid=str(chain_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=4,
        child_execution_order=4,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
    )
    expected_chain = ChainRun(
        uuid=str(chain_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        child_chain_runs=[],
        child_tool_runs=[expected_tool],
        child_llm_runs=[expected_chain_llm],
    )
    assert fake_tracer.runs[0] == expected_chain
    assert fake_tracer.runs == [expected_chain] * repetitions
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """An LLM run that errors yields an LLMRun with the error repr and no response."""
    failure = Exception("test")
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    fake_tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=run_id)
    fake_tracer.on_llm_error(failure, run_id=run_id)

    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=None,
        session_id=TEST_SESSION_ID,
        error=repr(failure),
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """A chain run that errors yields a ChainRun with the error repr and no outputs."""
    failure = Exception("test")
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    fake_tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=run_id)
    fake_tracer.on_chain_error(failure, run_id=run_id)

    expected_run = ChainRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs=None,
        session_id=TEST_SESSION_ID,
        error=repr(failure),
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """A tool run that errors yields a ToolRun with the error repr and no output."""
    failure = Exception("test")
    run_id = uuid4()
    # The clock is frozen by the decorator, so one timestamp covers both ends.
    frozen_now = datetime.utcnow()

    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    fake_tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=run_id
    )
    fake_tracer.on_tool_error(failure, run_id=run_id)

    expected_run = ToolRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        tool_input="test",
        output=None,
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=repr(failure),
    )
    assert fake_tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Errors inside a nested hierarchy are recorded on every failing V1 run."""
    failure = Exception("test")
    fake_tracer = FakeTracer()
    fake_tracer.new_session()
    chain_id = uuid4()
    tool_id = uuid4()
    first_llm_id = uuid4()
    second_llm_id = uuid4()
    failing_llm_id = uuid4()
    repetitions = 3

    for _ in range(repetitions):
        fake_tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_id
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=first_llm_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=first_llm_id
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=second_llm_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_end(
            response=LLMResult(generations=[[]]), run_id=second_llm_id
        )
        fake_tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_id,
            parent_run_id=chain_id,
        )
        fake_tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=failing_llm_id,
            parent_run_id=tool_id,
        )
        # The failure propagates outwards: llm, then tool, then chain.
        fake_tracer.on_llm_error(failure, run_id=failing_llm_id)
        fake_tracer.on_tool_error(failure, run_id=tool_id)
        fake_tracer.on_chain_error(failure, run_id=chain_id)

    # The clock is frozen by the decorator, so every timestamp is identical.
    frozen_now = datetime.utcnow()

    expected_first_llm = LLMRun(
        uuid=str(first_llm_id),
        parent_uuid=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=2,
        child_execution_order=2,
        serialized={"name": "llm"},
        session_id=TEST_SESSION_ID,
        error=None,
        prompts=[],
        response=LLMResult(generations=[[]], llm_output=None),
    )
    expected_second_llm = LLMRun(
        uuid=str(second_llm_id),
        parent_uuid=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=3,
        child_execution_order=3,
        serialized={"name": "llm"},
        session_id=TEST_SESSION_ID,
        error=None,
        prompts=[],
        response=LLMResult(generations=[[]], llm_output=None),
    )
    expected_failing_llm = LLMRun(
        uuid=str(failing_llm_id),
        parent_uuid=str(tool_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=5,
        child_execution_order=5,
        serialized={"name": "llm"},
        session_id=TEST_SESSION_ID,
        error=repr(failure),
        prompts=[],
        response=None,
    )
    expected_tool = ToolRun(
        uuid=str(tool_id),
        parent_uuid=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=4,
        child_execution_order=5,
        serialized={"name": "tool"},
        session_id=TEST_SESSION_ID,
        error=repr(failure),
        tool_input="test",
        output=None,
        action="{'name': 'tool'}",
        child_llm_runs=[expected_failing_llm],
        child_chain_runs=[],
        child_tool_runs=[],
    )
    expected_chain = ChainRun(
        uuid=str(chain_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={"name": "chain"},
        session_id=TEST_SESSION_ID,
        error=repr(failure),
        inputs={},
        outputs=None,
        child_llm_runs=[expected_first_llm, expected_second_llm],
        child_chain_runs=[],
        child_tool_runs=[expected_tool],
    )
    assert fake_tracer.runs == [expected_chain] * repetitions
@pytest.fixture
def sample_tracer_session_v1() -> TracerSessionV1:
    """Provide a V1 session with id 2 named "Sample session"."""
    sample_session = TracerSessionV1(id=2, name="Sample session")
    return sample_session
@freeze_time("2023-01-01")
def test_convert_run(
    lang_chain_tracer_v1: LangChainTracerV1,
    sample_tracer_session_v1: TracerSessionV1,
) -> None:
    """Check that ``_convert_to_v1_run`` yields the matching V1 run models.

    One ``Run`` per run type (llm, chain, tool) is converted and compared with
    a hand-built ``LLMRun`` / ``ChainRun`` / ``ToolRun``. The expected runs use
    ``session_id=2``, the id carried by the ``sample_tracer_session_v1``
    fixture that is assigned to the tracer before converting.
    """
    # The clock is frozen by the decorator, so one reading serves every field.
    frozen_now = datetime.utcnow()
    llm_id = "57a08cc4-73d2-4236-8370-549099d07fad"
    chain_id = "57a08cc4-73d2-4236-8371-549099d07fad"
    tool_id = "57a08cc4-73d2-4236-8372-549099d07fad"

    # --- inputs: new-style runs -------------------------------------------
    source_llm = Run(
        id=llm_id,
        name="llm_run",
        run_type=RunTypeEnum.llm,
        execution_order=1,
        child_execution_order=1,
        start_time=frozen_now,
        end_time=frozen_now,
        session_id=TEST_SESSION_ID,
        serialized={},
        extra={},
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]).dict(),
    )
    source_chain = Run(
        id=chain_id,
        name="chain_run",
        run_type=RunTypeEnum.chain,
        execution_order=1,
        child_execution_order=1,
        start_time=frozen_now,
        end_time=frozen_now,
        serialized={},
        extra={},
        inputs={},
        outputs={},
        child_runs=[source_llm],
    )
    source_tool = Run(
        id=tool_id,
        name="tool_run",
        run_type=RunTypeEnum.tool,
        execution_order=1,
        child_execution_order=1,
        start_time=frozen_now,
        end_time=frozen_now,
        serialized={},
        extra={},
        inputs={"input": "test"},
        outputs=None,
        child_runs=[],
    )

    # --- expectations: V1 runs --------------------------------------------
    want_llm = LLMRun(
        uuid=llm_id,
        name="llm_run",
        execution_order=1,
        child_execution_order=1,
        start_time=frozen_now,
        end_time=frozen_now,
        session_id=2,
        serialized={},
        extra={},
        prompts=[],
        response=LLMResult(generations=[[]]),
    )
    want_chain = ChainRun(
        uuid=chain_id,
        name="chain_run",
        execution_order=1,
        child_execution_order=1,
        start_time=frozen_now,
        end_time=frozen_now,
        session_id=2,
        serialized={},
        extra={},
        inputs={},
        outputs={},
        child_llm_runs=[want_llm],
        child_chain_runs=[],
        child_tool_runs=[],
    )
    want_tool = ToolRun(
        uuid=tool_id,
        name="tool_run",
        execution_order=1,
        child_execution_order=1,
        start_time=frozen_now,
        end_time=frozen_now,
        session_id=2,
        serialized={},
        extra={},
        tool_input="test",
        action="{}",
        child_llm_runs=[],
        child_chain_runs=[],
        child_tool_runs=[],
    )

    lang_chain_tracer_v1.session = sample_tracer_session_v1
    convert = lang_chain_tracer_v1._convert_to_v1_run
    got_llm = convert(source_llm)
    got_chain = convert(source_chain)
    got_tool = convert(source_tool)

    # Type checks first, then full equality, as separate assertions so a
    # failure points at the exact run that went wrong.
    for converted, v1_model in (
        (got_llm, LLMRun),
        (got_chain, ChainRun),
        (got_tool, ToolRun),
    ):
        assert isinstance(converted, v1_model)
    assert got_llm == want_llm
    assert got_tool == want_tool
    assert got_chain == want_chain

View File

@ -3,621 +3,90 @@ from __future__ import annotations
import json
from datetime import datetime
from typing import List, Tuple, Union
from unittest.mock import Mock, patch
from typing import Tuple
from unittest.mock import patch
from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import (
BaseTracer,
ChainRun,
LLMRun,
ToolRun,
TracerException,
TracerSession,
)
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import (
RunCreate,
TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
)
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
def load_session(session_name: str) -> TracerSession:
    """Return a tracing session with id 1 and the given name, started now."""
    started_at = datetime.utcnow()
    return TracerSession(id=1, name=session_name, start_time=started_at)
def _persist_session(session: TracerSessionBase) -> TracerSession:
    """Return a ``TracerSession`` built from ``session`` with the test id."""
    session_fields = session.dict()
    return TracerSession(id=TEST_SESSION_ID, **session_fields)
def load_default_session() -> TracerSession:
    """Return a tracing session with id 1 named "default", started now."""
    started_at = datetime.utcnow()
    return TracerSession(id=1, name="default", start_time=started_at)
class FakeTracer(BaseTracer):
    """In-memory tracer for tests.

    Persisted runs are appended to ``self.runs`` so tests can compare them
    with expected run objects; the session hooks delegate to the module-level
    helpers of the same names.
    """

    def __init__(self) -> None:
        """Set up the base tracer and an empty list of recorded runs."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Record ``run`` at the end of ``self.runs``."""
        self.runs.append(run)

    def _persist_session(self, session: TracerSessionBase) -> TracerSession:
        """Delegate to the module-level ``_persist_session`` helper."""
        persisted = _persist_session(session)
        return persisted

    def load_session(self, session_name: str) -> TracerSession:
        """Delegate to the module-level ``load_session`` helper."""
        loaded = load_session(session_name)
        return loaded

    def load_default_session(self) -> TracerSession:
        """Delegate to the module-level ``load_default_session`` helper."""
        loaded = load_default_session()
        return loaded
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """A single LLM start/end pair is recorded as one ``LLMRun``."""
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_llm_start(serialized={}, prompts=[], run_id=run_id)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        prompts=[],
        response=LLMResult(generations=[[]]),
        error=None,
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """A chat-model run driven through a ``CallbackManager`` is recorded.

    The run is started with ``messages=[[]]`` and the recorded ``LLMRun`` is
    expected to carry ``prompts=[""]``.
    """
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    callback_manager = CallbackManager(handlers=[tracer])
    llm_run_manager = callback_manager.on_chat_model_start(
        serialized={}, messages=[[]], run_id=run_id
    )
    llm_run_manager.on_llm_end(response=LLMResult(generations=[[]]))

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        prompts=[""],
        response=LLMResult(generations=[[]]),
        error=None,
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Ending an LLM run that was never started raises ``TracerException``."""
    tracer = FakeTracer()
    tracer.new_session()

    unknown_run_id = uuid4()
    empty_result = LLMResult(generations=[[]])
    with pytest.raises(TracerException):
        tracer.on_llm_end(response=empty_result, run_id=unknown_run_id)
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Ten LLM start/end pairs with one run id yield ten equal runs."""
    run_id = uuid4()
    num_runs = 10

    tracer = FakeTracer()
    tracer.new_session()
    for _ in range(num_runs):
        tracer.on_llm_start(serialized={}, prompts=[], run_id=run_id)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        prompts=[],
        response=LLMResult(generations=[[]]),
        error=None,
    )
    assert tracer.runs == [expected_run] * num_runs
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """A single chain start/end pair is recorded as one ``ChainRun``."""
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_chain_start(serialized={}, inputs={}, run_id=run_id)
    tracer.on_chain_end(outputs={}, run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = ChainRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        inputs={},
        outputs={},
        error=None,
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """A single tool start/end pair is recorded as one ``ToolRun``."""
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_tool_start(serialized={}, input_str="test", run_id=run_id)
    tracer.on_tool_end("test", run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = ToolRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        tool_input="test",
        output="test",
        action="{}",
        error=None,
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """A chain -> (tool -> llm), llm hierarchy is recorded as one nested run.

    The same callback sequence is replayed ten times with the same run ids and
    the tracer is expected to hold ten equal ``ChainRun`` trees.
    """
    repetitions = 10
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    for _ in range(repetitions):
        tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
        # Tool under the chain, with its own LLM call.
        tracer.on_tool_start(
            serialized={},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_tool_end("test", run_id=tool_uuid)
        # Second LLM call directly under the chain.
        tracer.on_llm_start(
            serialized={},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_chain_end(outputs={}, run_id=chain_uuid)

    # Build the expected tree bottom-up; the clock is frozen so a single
    # reading serves every start/end time.
    frozen_now = datetime.utcnow()
    llm_in_tool = LLMRun(
        uuid=str(llm_uuid1),
        parent_uuid=str(tool_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=3,
        child_execution_order=3,
        serialized={},
        extra={},
        prompts=[],
        response=LLMResult(generations=[[]]),
        error=None,
    )
    tool_in_chain = ToolRun(
        uuid=str(tool_uuid),
        parent_uuid=str(chain_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=2,
        child_execution_order=3,
        serialized={},
        extra={},
        tool_input="test",
        output="test",
        action="{}",
        error=None,
        child_chain_runs=[],
        child_tool_runs=[],
        child_llm_runs=[llm_in_tool],
    )
    llm_in_chain = LLMRun(
        uuid=str(llm_uuid2),
        parent_uuid=str(chain_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=4,
        child_execution_order=4,
        serialized={},
        extra={},
        prompts=[],
        response=LLMResult(generations=[[]]),
        error=None,
    )
    expected_chain = ChainRun(
        uuid=str(chain_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=4,
        serialized={},
        extra={},
        inputs={},
        outputs={},
        error=None,
        child_chain_runs=[],
        child_tool_runs=[tool_in_chain],
        child_llm_runs=[llm_in_chain],
    )
    assert tracer.runs == [expected_chain] * repetitions
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """An LLM run ended by ``on_llm_error`` stores ``repr`` of the error."""
    failure = Exception("test")
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_llm_start(serialized={}, prompts=[], run_id=run_id)
    tracer.on_llm_error(failure, run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        prompts=[],
        response=None,
        error=repr(failure),
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """A chain run ended by ``on_chain_error`` stores ``repr`` of the error."""
    failure = Exception("test")
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_chain_start(serialized={}, inputs={}, run_id=run_id)
    tracer.on_chain_error(failure, run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = ChainRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        inputs={},
        outputs=None,
        error=repr(failure),
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """A tool run ended by ``on_tool_error`` stores ``repr`` of the error."""
    failure = Exception("test")
    run_id = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_tool_start(serialized={}, input_str="test", run_id=run_id)
    tracer.on_tool_error(failure, run_id=run_id)

    # The clock is frozen, so start and end share the same timestamp.
    frozen_now = datetime.utcnow()
    expected_run = ToolRun(
        uuid=str(run_id),
        parent_uuid=None,
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=1,
        serialized={},
        extra={},
        tool_input="test",
        output=None,
        action="{}",
        error=repr(failure),
    )
    assert tracer.runs == [expected_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Errors in a nested llm -> tool -> chain hierarchy are all recorded.

    Two LLM calls under the chain succeed; a third LLM call under a tool
    fails, after which the tool and the chain are ended with the same
    exception. The sequence is replayed three times with the same run ids.
    """
    failure = Exception("test")
    repetitions = 3
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    llm_uuid3 = uuid4()

    tracer = FakeTracer()
    tracer.new_session()
    for _ in range(repetitions):
        tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
        # Two successful LLM calls directly under the chain.
        tracer.on_llm_start(
            serialized={},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_llm_start(
            serialized={},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        # A tool whose LLM call fails; the error is reported innermost first.
        tracer.on_tool_start(
            serialized={},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={},
            prompts=[],
            run_id=llm_uuid3,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_error(failure, run_id=llm_uuid3)
        tracer.on_tool_error(failure, run_id=tool_uuid)
        tracer.on_chain_error(failure, run_id=chain_uuid)

    # Build the expected tree bottom-up; the clock is frozen so a single
    # reading serves every start/end time.
    frozen_now = datetime.utcnow()
    # The successful LLM runs differ only in id and order (2 and 3).
    succeeded_llm_runs = [
        LLMRun(
            uuid=str(llm_id),
            parent_uuid=str(chain_uuid),
            session_id=TEST_SESSION_ID,
            start_time=frozen_now,
            end_time=frozen_now,
            execution_order=order,
            child_execution_order=order,
            serialized={},
            extra={},
            prompts=[],
            response=LLMResult(generations=[[]], llm_output=None),
            error=None,
        )
        for order, llm_id in enumerate((llm_uuid1, llm_uuid2), start=2)
    ]
    failed_llm_run = LLMRun(
        uuid=str(llm_uuid3),
        parent_uuid=str(tool_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=5,
        child_execution_order=5,
        serialized={},
        extra={},
        prompts=[],
        response=None,
        error=repr(failure),
    )
    failed_tool_run = ToolRun(
        uuid=str(tool_uuid),
        parent_uuid=str(chain_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=4,
        child_execution_order=5,
        serialized={},
        extra={},
        tool_input="test",
        output=None,
        action="{}",
        error=repr(failure),
        child_llm_runs=[failed_llm_run],
        child_chain_runs=[],
        child_tool_runs=[],
    )
    expected_chain = ChainRun(
        uuid=str(chain_uuid),
        session_id=TEST_SESSION_ID,
        start_time=frozen_now,
        end_time=frozen_now,
        execution_order=1,
        child_execution_order=5,
        serialized={},
        extra={},
        inputs={},
        outputs=None,
        error=repr(failure),
        child_llm_runs=succeeded_llm_runs,
        child_chain_runs=[],
        child_tool_runs=[failed_tool_run],
    )
    assert tracer.runs == [expected_chain] * repetitions
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV2()
tracer = LangChainTracer()
return tracer
# Mock a sample TracerSessionV2 object
# Mock a sample TracerSession object
@pytest.fixture
def sample_tracer_session_v2() -> TracerSessionV2:
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
def sample_tracer_session_v2() -> TracerSession:
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
# Mock sample LLMRun, ChainRun, and ToolRun objects
@freeze_time("2023-01-01")
@pytest.fixture
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
llm_run = LLMRun(
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
def sample_runs() -> Tuple[Run, Run, Run]:
llm_run = Run(
id="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
session_id=1,
prompts=[],
response=LLMResult(generations=[[]]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]).dict(),
serialized={},
extra={},
run_type=RunTypeEnum.llm,
)
chain_run = ChainRun(
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
child_execution_order=1,
session_id=1,
serialized={},
inputs={},
outputs=None,
child_llm_runs=[llm_run],
child_chain_runs=[],
child_tool_runs=[],
outputs={},
child_runs=[llm_run],
extra={},
run_type=RunTypeEnum.chain,
)
tool_run = ToolRun(
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
tool_run = Run(
id="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
session_id=1,
tool_input="test",
action="{}",
inputs={"input": "test"},
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
outputs=None,
serialized={},
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
child_runs=[],
extra={},
run_type=RunTypeEnum.tool,
)
return llm_run, chain_run, tool_run
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
    """The default query params hold the tenant id set by the tracer fixture."""
    query_params = lang_chain_tracer_v2._get_default_query_params()
    assert query_params == {"tenant_id": "test-tenant-id"}
@patch("langchain.callbacks.tracers.langchain.requests.get")
def test_load_session(
    mock_requests_get: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
) -> None:
    """``load_session`` queries the sessions endpoint and returns the session.

    ``requests.get`` is patched so that its JSON payload is a one-element list
    holding the sample session; the test checks both the request arguments
    and the returned object.
    """
    session_payload = [sample_tracer_session_v2.dict()]
    mock_requests_get.return_value.json.return_value = session_payload

    loaded_session = lang_chain_tracer_v2.load_session("test-session-name")

    expected_headers = {"Content-Type": "application/json", "x-api-key": "foo"}
    expected_params = {"tenant_id": "test-tenant-id", "name": "test-session-name"}
    mock_requests_get.assert_called_with(
        "http://test-endpoint.com/sessions",
        headers=expected_headers,
        params=expected_params,
    )
    assert loaded_session == sample_tracer_session_v2
def test_convert_run(
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
    sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
    """``_convert_run`` returns a ``RunCreate`` for LLM, chain and tool runs."""
    llm_run, chain_run, tool_run = sample_runs
    lang_chain_tracer_v2.session = sample_tracer_session_v2

    # Convert every run before asserting, in llm -> chain -> tool order.
    converted_runs = [
        lang_chain_tracer_v2._convert_run(run)
        for run in (llm_run, chain_run, tool_run)
    ]
    for converted in converted_runs:
        assert isinstance(converted, RunCreate)
def test_persist_run(
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
lang_chain_tracer_v2: LangChainTracer,
sample_tracer_session_v2: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test that persist_run method calls requests.post once per method call."""
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
@ -625,25 +94,25 @@ def test_persist_run(
) as get:
post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
llm_run, chain_run, tool_run = sample_runs
lang_chain_tracer_v2._persist_run(llm_run)
lang_chain_tracer_v2._persist_run(chain_run)
lang_chain_tracer_v2._persist_run(tool_run)
for run in sample_runs:
lang_chain_tracer_v2.run_map[str(run.id)] = run
for run in sample_runs:
lang_chain_tracer_v2._end_trace(run)
assert post.call_count == 3
assert get.call_count == 0
def test_persist_run_with_example_id(
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
lang_chain_tracer_v2: LangChainTracer,
sample_tracer_session_v2: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test the example ID is assigned only to the parent run and not the children."""
example_id = uuid4()
llm_run, chain_run, tool_run = sample_runs
chain_run.child_tool_runs = [tool_run]
tool_run.child_llm_runs = [llm_run]
chain_run.child_runs = [tool_run]
tool_run.child_runs = [llm_run]
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
"langchain.callbacks.tracers.langchain.requests.get"
) as get:
@ -652,55 +121,14 @@ def test_persist_run_with_example_id(
lang_chain_tracer_v2.example_id = example_id
lang_chain_tracer_v2._persist_run(chain_run)
assert post.call_count == 1
assert post.call_count == 3
assert get.call_count == 0
posted_data = json.loads(post.call_args[1]["data"])
assert posted_data["id"] == chain_run.uuid
assert posted_data["reference_example_id"] == str(example_id)
def assert_child_run_no_example_id(run: dict) -> None:
assert not run.get("reference_example_id")
for child_run in run.get("child_runs", []):
assert_child_run_no_example_id(child_run)
for child_run in posted_data["child_runs"]:
assert_child_run_no_example_id(child_run)
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
"""Test creating the 'SessionCreate' object."""
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
session_create = lang_chain_tracer_v2._get_session_create(name="test")
assert isinstance(session_create, TracerSessionV2Create)
assert session_create.name == "test"
assert session_create.tenant_id == _TENANT_ID
@patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_session(
mock_requests_post: Mock,
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
) -> None:
"""Test persist_session returns a TracerSessionV2 with the updated ID."""
session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
new_id = str(uuid4())
mock_requests_post.return_value.json.return_value = {"id": new_id}
result = lang_chain_tracer_v2._persist_session(session_create)
assert isinstance(result, TracerSessionV2)
res = sample_tracer_session_v2.dict()
res["id"] = UUID(new_id)
assert result.dict() == res
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
def test_load_default_session(
mock_load_session: Mock,
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
) -> None:
"""Test load_default_session attempts to load with the default name."""
mock_load_session.return_value = sample_tracer_session_v2
result = lang_chain_tracer_v2.load_default_session()
assert result == sample_tracer_session_v2
mock_load_session.assert_called_with("default")
posted_data = [
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
]
assert posted_data[0]["id"] == str(chain_run.id)
assert posted_data[0]["reference_example_id"] == str(example_id)
assert posted_data[1]["id"] == str(tool_run.id)
assert not posted_data[1].get("reference_example_id")
assert posted_data[2]["id"] == str(llm_run.id)
assert not posted_data[2].get("reference_example_id")

View File

@ -8,8 +8,8 @@ from unittest import mock
import pytest
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.langchain import (
LangChainPlusClient,
@ -196,10 +196,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
def mock_load_session(
self: Any, name: str, *args: Any, **kwargs: Any
) -> TracerSessionV2:
return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())
def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())
with mock.patch.object(
LangChainPlusClient, "read_dataset", new=mock_read_dataset
@ -208,7 +206,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
), mock.patch.object(
LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
), mock.patch.object(
LangChainTracerV2, "load_session", new=mock_load_session
LangChainTracer, "ensure_session", new=mock_ensure_session
):
monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
client = LangChainPlusClient(