From 928cdd57a4531e606f7ca7e34c0b96736ffcce49 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Sat, 13 May 2023 17:23:56 +0000 Subject: [PATCH] [Breaking] Refactor Base Tracer(#4549) ### Refactor the BaseTracer - Remove the 'session' abstraction from the BaseTracer - Rename 'RunV2' object(s) to be called 'Run' objects (Rename previous Run objects to be RunV1 objects) - Ditto for sessions: TracerSession*V2 -> TracerSession* - Remove now deprecated conversion from v1 run objects to v2 run objects in LangChainTracerV2 - Add conversion from v2 run objects to v1 run objects in V1 tracer --- langchain/callbacks/manager.py | 55 +- langchain/callbacks/tracers/__init__.py | 3 +- langchain/callbacks/tracers/base.py | 227 +- langchain/callbacks/tracers/langchain.py | 312 +-- langchain/callbacks/tracers/langchain_v1.py | 171 ++ langchain/callbacks/tracers/schemas.py | 50 +- langchain/client/langchain.py | 26 +- .../client/tracing_datasets.ipynb | 2223 ++++++++--------- .../callbacks/tracers/test_base_tracer.py | 464 ++++ .../callbacks/tracers/test_langchain_v1.py | 676 +++++ .../callbacks/tracers/test_tracer.py | 686 +---- tests/unit_tests/client/test_langchain.py | 12 +- 12 files changed, 2735 insertions(+), 2170 deletions(-) create mode 100644 langchain/callbacks/tracers/langchain_v1.py create mode 100644 tests/unit_tests/callbacks/tracers/test_base_tracer.py create mode 100644 tests/unit_tests/callbacks/tracers/test_langchain_v1.py diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 4622f2a32a..f8d692e80d 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -20,9 +20,9 @@ from langchain.callbacks.base import ( ) from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler -from langchain.callbacks.tracers.base import TracerSession -from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2 -from langchain.callbacks.tracers.schemas import TracerSessionV2 +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1 +from langchain.callbacks.tracers.schemas import TracerSession from langchain.schema import ( AgentAction, AgentFinish, @@ -37,11 +37,13 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None ) -tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 +tracing_callback_var: ContextVar[ + Optional[LangChainTracerV1] +] = ContextVar( # noqa: E501 "tracing_callback", default=None ) tracing_v2_callback_var: ContextVar[ - Optional[LangChainTracerV2] + Optional[LangChainTracer] ] = ContextVar( # noqa: E501 "tracing_callback_v2", default=None ) @@ -59,10 +61,10 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: @contextmanager def tracing_enabled( session_name: str = "default", -) -> Generator[TracerSession, None, None]: +) -> Generator[TracerSessionV1, None, None]: """Get Tracer in a context manager.""" - cb = LangChainTracer() - session = cast(TracerSession, cb.load_session(session_name)) + cb = LangChainTracerV1() + session = cast(TracerSessionV1, cb.load_session(session_name)) tracing_callback_var.set(cb) yield session tracing_callback_var.set(None) @@ -70,9 +72,12 @@ def tracing_enabled( @contextmanager def tracing_v2_enabled( - session_name: str = "default", + session_name: Optional[str] = None, + *, example_id: Optional[Union[str, UUID]] = None, -) -> Generator[TracerSessionV2, None, None]: + tenant_id: Optional[str] = None, + session_extra: Optional[Dict[str, Any]] = None, +) -> Generator[TracerSession, None, None]: """Get the experimental tracer handler in a context manager.""" # Issue a warning that this is experimental warnings.warn( @@ -81,11 +86,16 @@ def tracing_v2_enabled( ) if isinstance(example_id, str): example_id = UUID(example_id) - cb = LangChainTracerV2(example_id=example_id) - session = cast(TracerSessionV2, cb.new_session(session_name)) - tracing_callback_var.set(cb) + cb = LangChainTracer( + tenant_id=tenant_id, + session_name=session_name, + example_id=example_id, + session_extra=session_extra, + ) + session = cb.ensure_session() + tracing_v2_callback_var.set(cb) yield session - tracing_callback_var.set(None) + tracing_v2_callback_var.set(None) def _handle_event( @@ -829,32 +839,35 @@ def _configure( tracer_session = os.environ.get("LANGCHAIN_SESSION") if tracer_session is None: tracer_session = "default" - if verbose or tracing_enabled_ or open_ai is not None: + if verbose or tracing_enabled_ or tracing_v2_enabled_ or open_ai is not None: if verbose and not any( isinstance(handler, StdOutCallbackHandler) for handler in callback_manager.handlers ): callback_manager.add_handler(StdOutCallbackHandler(), False) if tracing_enabled_ and not any( - isinstance(handler, LangChainTracer) + isinstance(handler, LangChainTracerV1) for handler in callback_manager.handlers ): if tracer: callback_manager.add_handler(tracer, True) else: - handler = LangChainTracer() + handler = LangChainTracerV1() handler.load_session(tracer_session) callback_manager.add_handler(handler, True) if tracing_v2_enabled_ and not any( - isinstance(handler, LangChainTracerV2) + isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): if tracer_v2: callback_manager.add_handler(tracer_v2, True) else: - handler = LangChainTracerV2() - handler.load_session(tracer_session) - callback_manager.add_handler(handler, True) + try: + handler = LangChainTracer(session_name=tracer_session) + handler.ensure_session() + callback_manager.add_handler(handler, True) + except Exception as e: + logger.debug("Unable to load requested LangChainTracer", e) if open_ai is not None and not any( isinstance(handler, OpenAICallbackHandler) for handler in callback_manager.handlers diff --git a/langchain/callbacks/tracers/__init__.py b/langchain/callbacks/tracers/__init__.py index 5dd69b4871..12c1a34f99 100644 --- a/langchain/callbacks/tracers/__init__.py +++ b/langchain/callbacks/tracers/__init__.py @@ -1,5 +1,6 @@ """Tracers that record execution of LangChain runs.""" from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1 -__all__ = ["LangChainTracer"] +__all__ = ["LangChainTracer", "LangChainTracerV1"] diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 4516f8b5f3..f66863f9d0 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -7,15 +7,7 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID from langchain.callbacks.base import BaseCallbackHandler -from langchain.callbacks.tracers.schemas import ( - ChainRun, - LLMRun, - ToolRun, - TracerSession, - TracerSessionBase, - TracerSessionCreate, - TracerSessionV2, -) +from langchain.callbacks.tracers.schemas import Run, RunTypeEnum from langchain.schema import LLMResult @@ -28,89 +20,45 @@ class BaseTracer(BaseCallbackHandler, ABC): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} - self.session: Optional[Union[TracerSession, TracerSessionV2]] = None + self.run_map: Dict[str, Run] = {} @staticmethod def _add_child_run( - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], + parent_run: Run, + child_run: Run, ) -> None: """Add child run to a chain run or tool run.""" - if isinstance(child_run, LLMRun): - parent_run.child_llm_runs.append(child_run) - elif isinstance(child_run, ChainRun): - parent_run.child_chain_runs.append(child_run) - elif isinstance(child_run, ToolRun): - parent_run.child_tool_runs.append(child_run) - else: - raise TracerException(f"Invalid run type: {type(child_run)}") + parent_run.child_runs.append(child_run) @abstractmethod - def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + def _persist_run(self, run: Run) -> None: """Persist a run.""" - @abstractmethod - def _persist_session( - self, session: TracerSessionBase - ) -> Union[TracerSession, TracerSessionV2]: - """Persist a tracing session.""" - - def _get_session_create( - self, name: Optional[str] = None, **kwargs: Any - ) -> TracerSessionBase: - return TracerSessionCreate(name=name, extra=kwargs) - - def new_session( - self, name: Optional[str] = None, **kwargs: Any - ) -> Union[TracerSession, TracerSessionV2]: - """NOT thread safe, do not call this method from multiple threads.""" - session_create = self._get_session_create(name=name, **kwargs) - session = self._persist_session(session_create) - self.session = session - return session - - @abstractmethod - def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]: - """Load a tracing session and set it as the Tracer's session.""" - - @abstractmethod - def load_default_session(self) -> Union[TracerSession, TracerSessionV2]: - """Load the default tracing session and set it as the Tracer's session.""" - - def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + def _start_trace(self, run: Run) -> None: """Start a trace for a run.""" - if run.parent_uuid: - parent_run = self.run_map[run.parent_uuid] + if run.parent_run_id: + parent_run = self.run_map[str(run.parent_run_id)] if parent_run: - if isinstance(parent_run, LLMRun): - raise TracerException( - "Cannot add child run to an LLM run. " - "LLM runs are not allowed to have children." - ) self._add_child_run(parent_run, run) else: raise TracerException( - f"Parent run with UUID {run.parent_uuid} not found." + f"Parent run with UUID {run.parent_run_id} not found." ) + self.run_map[str(run.id)] = run - self.run_map[run.uuid] = run - - def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + def _end_trace(self, run: Run) -> None: """End a trace for a run.""" - if not run.parent_uuid: + if not run.parent_run_id: self._persist_run(run) else: - parent_run = self.run_map.get(run.parent_uuid) + parent_run = self.run_map.get(str(run.parent_run_id)) if parent_run is None: raise TracerException( - f"Parent run with UUID {run.parent_uuid} not found." + f"Parent run with UUID {run.parent_run_id} not found." ) - if isinstance(parent_run, LLMRun): - raise TracerException("LLM Runs are not allowed to have children. ") if run.child_execution_order > parent_run.child_execution_order: parent_run.child_execution_order = run.child_execution_order - self.run_map.pop(run.uuid) + self.run_map.pop(str(run.id)) def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: """Get the execution order for a run.""" @@ -121,9 +69,6 @@ class BaseTracer(BaseCallbackHandler, ABC): if parent_run is None: raise TracerException(f"Parent run with UUID {parent_run_id} not found.") - if isinstance(parent_run, LLMRun): - raise TracerException("LLM Runs are not allowed to have children. ") - return parent_run.child_execution_order + 1 def on_llm_start( @@ -136,25 +81,22 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" - if self.session is None: - self.session = self.load_default_session() - - run_id_ = str(run_id) parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - llm_run = LLMRun( - uuid=run_id_, - parent_uuid=parent_run_id_, + llm_run = Run( + id=run_id, + name=serialized.get("name"), + parent_run_id=parent_run_id, serialized=serialized, - prompts=prompts, + inputs={"prompts": prompts}, extra=kwargs, start_time=datetime.utcnow(), execution_order=execution_order, child_execution_order=execution_order, - session_id=self.session.id, + run_type=RunTypeEnum.llm, ) self._start_trace(llm_run) + self._on_llm_start(llm_run) def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for an LLM run.""" @@ -163,11 +105,12 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id_ = str(run_id) llm_run = self.run_map.get(run_id_) - if llm_run is None or not isinstance(llm_run, LLMRun): - raise TracerException("No LLMRun found to be traced") - llm_run.response = response + if llm_run is None or llm_run.run_type != RunTypeEnum.llm: + raise TracerException("No LLM Run found to be traced") + llm_run.outputs = response.dict() llm_run.end_time = datetime.utcnow() self._end_trace(llm_run) + self._on_llm_end(llm_run) def on_llm_error( self, @@ -182,12 +125,12 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id_ = str(run_id) llm_run = self.run_map.get(run_id_) - if llm_run is None or not isinstance(llm_run, LLMRun): - raise TracerException("No LLMRun found to be traced") - + if llm_run is None or llm_run.run_type != RunTypeEnum.llm: + raise TracerException("No LLM Run found to be traced") llm_run.error = repr(error) llm_run.end_time = datetime.utcnow() self._end_trace(llm_run) + self._on_chain_error(llm_run) def on_chain_start( self, @@ -199,16 +142,12 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> None: """Start a trace for a chain run.""" - if self.session is None: - self.session = self.load_default_session() - - run_id_ = str(run_id) parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - chain_run = ChainRun( - uuid=run_id_, - parent_uuid=parent_run_id_, + chain_run = Run( + id=run_id, + name=serialized.get("name"), + parent_run_id=parent_run_id, serialized=serialized, inputs=inputs, extra=kwargs, @@ -216,23 +155,25 @@ class BaseTracer(BaseCallbackHandler, ABC): execution_order=execution_order, child_execution_order=execution_order, child_runs=[], - session_id=self.session.id, + run_type=RunTypeEnum.chain, ) self._start_trace(chain_run) + self._on_chain_start(chain_run) def on_chain_end( self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any ) -> None: """End a trace for a chain run.""" - run_id_ = str(run_id) - - chain_run = self.run_map.get(run_id_) - if chain_run is None or not isinstance(chain_run, ChainRun): - raise TracerException("No ChainRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_chain_end callback.") + chain_run = self.run_map.get(str(run_id)) + if chain_run is None or chain_run.run_type != RunTypeEnum.chain: + raise TracerException("No chain Run found to be traced") chain_run.outputs = outputs chain_run.end_time = datetime.utcnow() self._end_trace(chain_run) + self._on_chain_end(chain_run) def on_chain_error( self, @@ -242,15 +183,16 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> None: """Handle an error for a chain run.""" - run_id_ = str(run_id) - - chain_run = self.run_map.get(run_id_) - if chain_run is None or not isinstance(chain_run, ChainRun): - raise TracerException("No ChainRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_chain_error callback.") + chain_run = self.run_map.get(str(run_id)) + if chain_run is None or chain_run.run_type != RunTypeEnum.chain: + raise TracerException("No chain Run found to be traced") chain_run.error = repr(error) chain_run.end_time = datetime.utcnow() self._end_trace(chain_run) + self._on_chain_error(chain_run) def on_tool_start( self, @@ -262,40 +204,36 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> None: """Start a trace for a tool run.""" - if self.session is None: - self.session = self.load_default_session() - - run_id_ = str(run_id) parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - tool_run = ToolRun( - uuid=run_id_, - parent_uuid=parent_run_id_, + tool_run = Run( + id=run_id, + name=serialized.get("name"), + parent_run_id=parent_run_id, serialized=serialized, - # TODO: this is duplicate info as above, not needed. - action=str(serialized), - tool_input=input_str, + inputs={"input": input_str}, extra=kwargs, start_time=datetime.utcnow(), execution_order=execution_order, child_execution_order=execution_order, child_runs=[], - session_id=self.session.id, + run_type=RunTypeEnum.tool, ) self._start_trace(tool_run) + self._on_tool_start(tool_run) def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for a tool run.""" - run_id_ = str(run_id) + if not run_id: + raise TracerException("No run_id provided for on_tool_end callback.") + tool_run = self.run_map.get(str(run_id)) + if tool_run is None or tool_run.run_type != RunTypeEnum.tool: + raise TracerException("No tool Run found to be traced") - tool_run = self.run_map.get(run_id_) - if tool_run is None or not isinstance(tool_run, ToolRun): - raise TracerException("No ToolRun found to be traced") - - tool_run.output = output + tool_run.outputs = {"output": output} tool_run.end_time = datetime.utcnow() self._end_trace(tool_run) + self._on_tool_end(tool_run) def on_tool_error( self, @@ -305,15 +243,16 @@ class BaseTracer(BaseCallbackHandler, ABC): **kwargs: Any, ) -> None: """Handle an error for a tool run.""" - run_id_ = str(run_id) - - tool_run = self.run_map.get(run_id_) - if tool_run is None or not isinstance(tool_run, ToolRun): - raise TracerException("No ToolRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_tool_error callback.") + tool_run = self.run_map.get(str(run_id)) + if tool_run is None or tool_run.run_type != RunTypeEnum.tool: + raise TracerException("No tool Run found to be traced") tool_run.error = repr(error) tool_run.end_time = datetime.utcnow() self._end_trace(tool_run) + self._on_tool_error(tool_run) def __deepcopy__(self, memo: dict) -> BaseTracer: """Deepcopy the tracer.""" @@ -322,3 +261,33 @@ class BaseTracer(BaseCallbackHandler, ABC): def __copy__(self) -> BaseTracer: """Copy the tracer.""" return self + + def _on_llm_start(self, run: Run) -> None: + """Process the LLM Run upon start.""" + + def _on_llm_end(self, run: Run) -> None: + """Process the LLM Run.""" + + def _on_llm_error(self, run: Run) -> None: + """Process the LLM Run upon error.""" + + def _on_chain_start(self, run: Run) -> None: + """Process the Chain Run upon start.""" + + def _on_chain_end(self, run: Run) -> None: + """Process the Chain Run.""" + + def _on_chain_error(self, run: Run) -> None: + """Process the Chain Run upon error.""" + + def _on_tool_start(self, run: Run) -> None: + """Process the Tool Run upon start.""" + + def _on_tool_end(self, run: Run) -> None: + """Process the Tool Run.""" + + def _on_tool_error(self, run: Run) -> None: + """Process the Tool Run upon error.""" + + def _on_chat_model_start(self, run: Run) -> None: + """Process the Chat Model Run upon start.""" diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 65bddfcc3a..e893c07b99 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -4,27 +4,24 @@ from __future__ import annotations import logging import os from datetime import datetime -from typing import Any, Dict, List, Optional, Union -from uuid import UUID, uuid4 +from typing import Any, Dict, List, Optional +from uuid import UUID import requests from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import ( - ChainRun, - LLMRun, + Run, RunCreate, - ToolRun, + RunTypeEnum, TracerSession, - TracerSessionBase, - TracerSessionV2, - TracerSessionV2Create, + TracerSessionCreate, ) from langchain.schema import BaseMessage, messages_to_dict from langchain.utils import raise_for_status_with_text -def _get_headers() -> Dict[str, Any]: +def get_headers() -> Dict[str, Any]: """Get the headers for the LangChain API.""" headers: Dict[str, Any] = {"Content-Type": "application/json"} if os.getenv("LANGCHAIN_API_KEY"): @@ -32,168 +29,47 @@ def _get_headers() -> Dict[str, Any]: return headers -def _get_endpoint() -> str: +def get_endpoint() -> str: return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") +def _get_tenant_id( + tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict] +) -> str: + """Get the tenant ID for the LangChain API.""" + tenant_id_: Optional[str] = tenant_id or os.getenv("LANGCHAIN_TENANT_ID") + if tenant_id_: + return tenant_id_ + endpoint_ = endpoint or get_endpoint() + headers_ = headers or get_headers() + response = requests.get(endpoint_ + "/tenants", headers=headers_) + raise_for_status_with_text(response) + tenants: List[Dict[str, Any]] = response.json() + if not tenants: + raise ValueError(f"No tenants found for URL {endpoint_}") + return tenants[0]["id"] + + class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, + tenant_id: Optional[str] = None, + example_id: Optional[UUID] = None, + session_name: Optional[str] = None, + session_extra: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: """Initialize the LangChain tracer.""" super().__init__(**kwargs) - self._endpoint = _get_endpoint() - self._headers = _get_headers() - - def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: - """Persist a run.""" - if isinstance(run, LLMRun): - endpoint = f"{self._endpoint}/llm-runs" - elif isinstance(run, ChainRun): - endpoint = f"{self._endpoint}/chain-runs" - else: - endpoint = f"{self._endpoint}/tool-runs" - - try: - response = requests.post( - endpoint, - data=run.json(), - headers=self._headers, - ) - raise_for_status_with_text(response) - except Exception as e: - logging.warning(f"Failed to persist run: {e}") - - def _persist_session( - self, session_create: TracerSessionBase - ) -> Union[TracerSession, TracerSessionV2]: - """Persist a session.""" - try: - r = requests.post( - f"{self._endpoint}/sessions", - data=session_create.json(), - headers=self._headers, - ) - session = TracerSession(id=r.json()["id"], **session_create.dict()) - except Exception as e: - logging.warning(f"Failed to create session, using default session: {e}") - session = TracerSession(id=1, **session_create.dict()) - return session - - def _load_session(self, session_name: Optional[str] = None) -> TracerSession: - """Load a session from the tracer.""" - try: - url = f"{self._endpoint}/sessions" - if session_name: - url += f"?name={session_name}" - r = requests.get(url, headers=self._headers) - - tracer_session = TracerSession(**r.json()[0]) - except Exception as e: - session_type = "default" if not session_name else session_name - logging.warning( - f"Failed to load {session_type} session, using empty session: {e}" - ) - tracer_session = TracerSession(id=1) - - self.session = tracer_session - return tracer_session - - def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]: - """Load a session with the given name from the tracer.""" - return self._load_session(session_name) - - def load_default_session(self) -> Union[TracerSession, TracerSessionV2]: - """Load the default tracing session and set it as the Tracer's session.""" - return self._load_session("default") - - -def _get_tenant_id() -> Optional[str]: - """Get the tenant ID for the LangChain API.""" - tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID") - if tenant_id: - return tenant_id - endpoint = _get_endpoint() - headers = _get_headers() - response = requests.get(endpoint + "/tenants", headers=headers) - raise_for_status_with_text(response) - tenants: List[Dict[str, Any]] = response.json() - if not tenants: - raise ValueError(f"No tenants found for URL {endpoint}") - return tenants[0]["id"] - - -class LangChainTracerV2(LangChainTracer): - """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - - def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None: - """Initialize the LangChain tracer.""" - super().__init__(**kwargs) - self._endpoint = _get_endpoint() - self._headers = _get_headers() - self.tenant_id = _get_tenant_id() + self.session: Optional[TracerSession] = None + self._endpoint = get_endpoint() + self._headers = get_headers() + self.tenant_id = tenant_id self.example_id = example_id - - def _get_session_create( - self, name: Optional[str] = None, **kwargs: Any - ) -> TracerSessionBase: - return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id) - - def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2: - """Persist a session.""" - session: Optional[TracerSessionV2] = None - try: - r = requests.post( - f"{self._endpoint}/sessions", - data=session_create.json(), - headers=self._headers, - ) - raise_for_status_with_text(r) - creation_args = session_create.dict() - if "id" in creation_args: - del creation_args["id"] - return TracerSessionV2(id=r.json()["id"], **creation_args) - except Exception as e: - if session_create.name is not None: - try: - return self.load_session(session_create.name) - except Exception: - pass - logging.warning( - f"Failed to create session {session_create.name}," - f" using empty session: {e}" - ) - session = TracerSessionV2(id=uuid4(), **session_create.dict()) - - return session - - def _get_default_query_params(self) -> Dict[str, Any]: - """Get the query params for the LangChain API.""" - return {"tenant_id": self.tenant_id} - - def load_session(self, session_name: str) -> TracerSessionV2: - """Load a session with the given name from the tracer.""" - try: - url = f"{self._endpoint}/sessions" - params = {"tenant_id": self.tenant_id} - if session_name: - params["name"] = session_name - r = requests.get(url, headers=self._headers, params=params) - raise_for_status_with_text(r) - tracer_session = TracerSessionV2(**r.json()[0]) - except Exception as e: - session_type = "default" if not session_name else session_name - logging.warning( - f"Failed to load {session_type} session, using empty session: {e}" - ) - tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id) - - self.session = tracer_session - return tracer_session - - def load_default_session(self) -> TracerSessionV2: - """Load the default tracing session and set it as the Tracer's session.""" - return self.load_session("default") + self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default") + self.session_extra = session_extra def on_chat_model_start( self, @@ -205,81 +81,56 @@ class LangChainTracerV2(LangChainTracer): **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" - if self.session is None: - self.session = self.load_default_session() - - run_id_ = str(run_id) parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - llm_run = LLMRun( - uuid=run_id_, - parent_uuid=parent_run_id_, + chat_model_run = Run( + id=run_id, + name=serialized.get("name"), + parent_run_id=parent_run_id, serialized=serialized, - prompts=[], - extra={**kwargs, "messages": messages}, + inputs={"messages": messages_to_dict(batch) for batch in messages}, + extra=kwargs, start_time=datetime.utcnow(), execution_order=execution_order, child_execution_order=execution_order, - session_id=self.session.id, + run_type=RunTypeEnum.llm, ) - self._start_trace(llm_run) + self._start_trace(chat_model_run) + self._on_chat_model_start(chat_model_run) - def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate: - """Convert a run to a Run.""" - session = self.session or self.load_default_session() - inputs: Dict[str, Any] = {} - outputs: Optional[Dict[str, Any]] = None - child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] - if isinstance(run, LLMRun): - run_type = "llm" - if run.extra is not None and "messages" in run.extra: - messages: List[List[BaseMessage]] = run.extra.pop("messages") - converted_messages = [messages_to_dict(batch) for batch in messages] - inputs = {"messages": converted_messages} - else: - inputs = {"prompts": run.prompts} - outputs = run.response.dict() if run.response else {} - child_runs = [] - elif isinstance(run, ChainRun): - run_type = "chain" - inputs = run.inputs - outputs = run.outputs - child_runs = [ - *run.child_llm_runs, - *run.child_chain_runs, - *run.child_tool_runs, - ] - else: - run_type = "tool" - inputs = {"input": run.tool_input} - outputs = {"output": run.output} if run.output else {} - child_runs = [ - *run.child_llm_runs, - *run.child_chain_runs, - *run.child_tool_runs, - ] - - return RunCreate( - id=run.uuid, - name=run.serialized.get("name"), - start_time=run.start_time, - end_time=run.end_time, - extra=run.extra or {}, - error=run.error, - execution_order=run.execution_order, - serialized=run.serialized, - inputs=inputs, - outputs=outputs, - session_id=session.id, - run_type=run_type, - child_runs=[self._convert_run(child) for child in child_runs], + def ensure_tenant_id(self) -> str: + """Load or use the tenant ID.""" + tenant_id = self.tenant_id or _get_tenant_id( + self.tenant_id, self._endpoint, self._headers ) + self.tenant_id = tenant_id + return tenant_id - def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + def ensure_session(self) -> TracerSession: + """Upsert a session.""" + if self.session is not None: + return self.session + tenant_id = self.ensure_tenant_id() + url = f"{self._endpoint}/sessions?upsert=true" + session_create = TracerSessionCreate( + name=self.session_name, extra=self.session_extra, tenant_id=tenant_id + ) + r = requests.post( + url, + data=session_create.json(), + headers=self._headers, + ) + raise_for_status_with_text(r) + self.session = TracerSession(**r.json()) + return self.session + + def _persist_run_nested(self, run: Run) -> None: """Persist a run.""" - run_create = self._convert_run(run) - run_create.reference_example_id = self.example_id + session = self.ensure_session() + child_runs = run.child_runs + run_dict = run.dict() + del run_dict["child_runs"] + run_create = RunCreate(**run_dict, session_id=session.id) try: response = requests.post( f"{self._endpoint}/runs", @@ -289,3 +140,12 @@ class LangChainTracerV2(LangChainTracer): raise_for_status_with_text(response) except Exception as e: logging.warning(f"Failed to persist run: {e}") + for child_run in child_runs: + child_run.parent_run_id = run.id + self._persist_run_nested(child_run) + + def _persist_run(self, run: Run) -> None: + """Persist a run.""" + run.reference_example_id = self.example_id + # TODO: Post first then patch + self._persist_run_nested(run) diff --git a/langchain/callbacks/tracers/langchain_v1.py b/langchain/callbacks/tracers/langchain_v1.py new file mode 100644 index 0000000000..31b7630dab --- /dev/null +++ b/langchain/callbacks/tracers/langchain_v1.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import logging +from typing import Any, Optional, Union + +import requests + +from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.langchain import get_endpoint, get_headers +from langchain.callbacks.tracers.schemas import ( + ChainRun, + LLMRun, + Run, + ToolRun, + TracerSession, + TracerSessionV1, + TracerSessionV1Base, +) +from langchain.schema import get_buffer_string +from langchain.utils import raise_for_status_with_text + + +class LangChainTracerV1(BaseTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.session: Optional[TracerSessionV1] = None + self._endpoint = get_endpoint() + self._headers = get_headers() + + def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]: + session = self.session or self.load_default_session() + if not isinstance(session, TracerSessionV1): + raise ValueError( + "LangChainTracerV1 is not compatible with" + f" session of type {type(session)}" + ) + + if run.run_type == "llm": + if "prompts" in run.inputs: + prompts = run.inputs["prompts"] + elif "messages" in run.inputs: + prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]] + else: + raise ValueError("No prompts found in LLM run inputs") + return LLMRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + extra=run.extra, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + error=run.error, + prompts=prompts, + response=run.outputs if run.outputs else None, + ) + if run.run_type == "chain": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ChainRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + inputs=run.inputs, + outputs=run.outputs, + error=run.error, + extra=run.extra, + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + ) + if run.run_type == "tool": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ToolRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + action=str(run.serialized), + tool_input=run.inputs.get("input", ""), + output=None if run.outputs is None else run.outputs.get("output"), + error=run.error, + extra=run.extra, + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + ) + raise ValueError(f"Unknown run type: {run.run_type}") + + def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, Run): + v1_run = self._convert_to_v1_run(run) + else: + v1_run = run + if isinstance(v1_run, LLMRun): + endpoint = f"{self._endpoint}/llm-runs" + elif isinstance(v1_run, ChainRun): + endpoint = f"{self._endpoint}/chain-runs" + else: + endpoint = f"{self._endpoint}/tool-runs" + + try: + response = requests.post( + endpoint, + data=v1_run.json(), + headers=self._headers, + ) + raise_for_status_with_text(response) + except Exception as e: + logging.warning(f"Failed to persist run: {e}") + + def _persist_session( + self, session_create: TracerSessionV1Base + ) -> Union[TracerSessionV1, TracerSession]: + """Persist a session.""" + try: + r = requests.post( + f"{self._endpoint}/sessions", + data=session_create.json(), + headers=self._headers, + ) + session = TracerSessionV1(id=r.json()["id"], **session_create.dict()) + except Exception as e: + logging.warning(f"Failed to create session, using default session: {e}") + session = TracerSessionV1(id=1, **session_create.dict()) + return session + + def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1: + """Load a session from the tracer.""" + try: + url = f"{self._endpoint}/sessions" + if session_name: + url += f"?name={session_name}" + r = requests.get(url, headers=self._headers) + + tracer_session = TracerSessionV1(**r.json()[0]) + except Exception as e: + session_type = "default" if not session_name else session_name + logging.warning( + f"Failed to load {session_type} session, using empty session: {e}" + ) + tracer_session = TracerSessionV1(id=1) + + self.session = tracer_session + return tracer_session + + def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]: + """Load a session with the given name from the tracer.""" + return self._load_session(session_name) + + def load_default_session(self) -> Union[TracerSessionV1, TracerSession]: + """Load the default tracing session and set it as the Tracer's session.""" + return self._load_session("default") diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index 39a0ab0ea4..221cd14c51 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -6,47 +6,45 @@ from enum import Enum from typing import Any, Dict, List, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, root_validator from langchain.schema import LLMResult -class TracerSessionBase(BaseModel): - """Base class for TracerSession.""" +class TracerSessionV1Base(BaseModel): + """Base class for TracerSessionV1.""" start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) name: Optional[str] = None extra: Optional[Dict[str, Any]] = None -class TracerSessionCreate(TracerSessionBase): - """Create class for TracerSession.""" +class TracerSessionV1Create(TracerSessionV1Base): + """Create class for TracerSessionV1.""" pass -class TracerSession(TracerSessionBase): - """TracerSession schema.""" +class TracerSessionV1(TracerSessionV1Base): + """TracerSessionV1 schema.""" id: int -class TracerSessionV2Base(TracerSessionBase): - """A creation class for TracerSessionV2.""" +class TracerSessionBase(TracerSessionV1Base): + """A creation class for TracerSession.""" tenant_id: UUID -class TracerSessionV2Create(TracerSessionV2Base): - """A creation class for TracerSessionV2.""" +class TracerSessionCreate(TracerSessionBase): + """A creation class for TracerSession.""" id: Optional[UUID] - pass - -class TracerSessionV2(TracerSessionV2Base): - """TracerSession schema for the V2 API.""" +class TracerSession(TracerSessionBase): + """TracerSessionV1 schema for the V2 API.""" id: UUID @@ -111,26 +109,32 @@ class RunBase(BaseModel): extra: dict error: Optional[str] execution_order: int + child_execution_order: int serialized: dict inputs: dict outputs: Optional[dict] - session_id: UUID reference_example_id: Optional[UUID] run_type: RunTypeEnum parent_run_id: Optional[UUID] -class RunCreate(RunBase): - """Schema to create a run in the DB.""" - - name: Optional[str] - child_runs: List[RunCreate] = Field(default_factory=list) - - class Run(RunBase): """Run schema when loading from the DB.""" name: str + child_runs: List[Run] = Field(default_factory=list) + + @root_validator(pre=True) + def assign_name(cls, values: dict) -> dict: + """Assign name to the run.""" + if "name" not in values: + values["name"] = values["serialized"]["name"] + return values + + +class RunCreate(RunBase): + name: str + session_id: UUID ChainRun.update_forward_refs() diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 8df0219fa8..de330028f6 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -27,7 +27,7 @@ from requests import Response from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import tracing_v2_enabled -from langchain.callbacks.tracers.langchain import LangChainTracerV2 +from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate @@ -308,7 +308,7 @@ class LangChainPlusClient(BaseSettings): async def _arun_llm( llm: BaseLanguageModel, inputs: Dict[str, Any], - langchain_tracer: LangChainTracerV2, + langchain_tracer: LangChainTracer, ) -> Union[LLMResult, ChatResult]: if isinstance(llm, BaseLLM): if "prompt" not in inputs: @@ -328,7 +328,7 @@ class LangChainPlusClient(BaseSettings): @staticmethod async def _arun_llm_or_chain( example: Example, - langchain_tracer: LangChainTracerV2, + langchain_tracer: LangChainTracer, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, n_repetitions: int, ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: @@ -358,8 +358,8 @@ class LangChainPlusClient(BaseSettings): @staticmethod async def _gather_with_concurrency( n: int, - initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracerV2, Dict]]], - *async_funcs: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]], + initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]], + *async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]], ) -> List[Any]: """ Run coroutines with a concurrency limit. @@ -376,7 +376,7 @@ class LangChainPlusClient(BaseSettings): tracer, job_state = await initializer() async def run_coroutine_with_semaphore( - async_func: Callable[[LangChainTracerV2, Dict], Coroutine[Any, Any, Any]] + async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]] ) -> Any: async with semaphore: return await async_func(tracer, job_state) @@ -387,7 +387,7 @@ class LangChainPlusClient(BaseSettings): async def _tracer_initializer( self, session_name: str - ) -> Tuple[LangChainTracerV2, dict]: + ) -> Tuple[LangChainTracer, dict]: """ Initialize a tracer to share across tasks. @@ -395,11 +395,11 @@ class LangChainPlusClient(BaseSettings): session_name: The session name for the tracer. Returns: - A LangChainTracerV2 instance with an active session. + A LangChainTracer instance with an active session. """ job_state = {"num_processed": 0} with tracing_v2_enabled(session_name=session_name) as session: - tracer = LangChainTracerV2() + tracer = LangChainTracer() tracer.session = session return tracer, job_state @@ -440,7 +440,7 @@ class LangChainPlusClient(BaseSettings): results: Dict[str, List[Any]] = {} async def process_example( - example: Example, tracer: LangChainTracerV2, job_state: dict + example: Example, tracer: LangChainTracer, job_state: dict ) -> None: """Process a single example.""" result = await LangChainPlusClient._arun_llm_or_chain( @@ -469,7 +469,7 @@ class LangChainPlusClient(BaseSettings): def run_llm( llm: BaseLanguageModel, inputs: Dict[str, Any], - langchain_tracer: LangChainTracerV2, + langchain_tracer: LangChainTracer, ) -> Union[LLMResult, ChatResult]: """Run the language model on the example.""" if isinstance(llm, BaseLLM): @@ -492,7 +492,7 @@ class LangChainPlusClient(BaseSettings): @staticmethod def run_llm_or_chain( example: Example, - langchain_tracer: LangChainTracerV2, + langchain_tracer: LangChainTracer, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, n_repetitions: int, ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: @@ -551,7 +551,7 @@ class LangChainPlusClient(BaseSettings): examples = list(self.list_examples(dataset_id=str(dataset.id))) results: Dict[str, Any] = {} with tracing_v2_enabled(session_name=session_name) as session: - tracer = LangChainTracerV2() + tracer = LangChainTracer() tracer.session = session for i, example in enumerate(examples): diff --git a/langchain/experimental/client/tracing_datasets.ipynb b/langchain/experimental/client/tracing_datasets.ipynb index dddbc52d63..dfb1b8198c 100644 --- a/langchain/experimental/client/tracing_datasets.ipynb +++ b/langchain/experimental/client/tracing_datasets.ipynb @@ -1,1128 +1,1109 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "1a4596ea-a631-416d-a2a4-3577c140493d", - "metadata": {}, - "source": [ - "# Running Chains on Traced Datasets\n", - "\n", - "Developing applications with language models can be uniquely challenging. To manage this complexity and ensure reliable performance, LangChain provides tracing and evaluation functionality. This notebook demonstrates how to run Chains, which are language model functions, as well as Chat models, and LLMs on previously captured datasets or traces. Some common use cases for this approach include:\n", - "\n", - "- Running an evaluation chain to grade previous runs.\n", - "- Comparing different chains, LLMs, and agents on traced datasets.\n", - "- Executing a stochastic chain multiple times over a dataset to generate metrics before deployment.\n", - "\n", - "Please note that this notebook assumes you have LangChain+ tracing running in the background. It is also configured to work only with the V2 endpoints. To set it up, follow the [tracing directions here](..\\/..\\/tracing\\/local_installation.md).\n", - " \n", - "We'll start by creating a client to connect to LangChain+." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "904db9a5-f387-4a57-914c-c8af8d39e249", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "You can click the link below to view the UI\n" - ] - }, - { - "data": { - "text/html": [ - "LangChain+ Client" - ], - "text/plain": [ - "LangChainPlusClient (API URL: http://localhost:8000)" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain.client import LangChainPlusClient\n", - "\n", - "client = LangChainPlusClient(\n", - " api_url=\"http://localhost:8000\",\n", - " api_key=None,\n", - " # tenant_id=\"your_tenant_uuid\", # This is required when connecting to a hosted LangChain instance\n", - ")\n", - "print(\"You can click the link below to view the UI\")\n", - "client" - ] - }, - { - "cell_type": "markdown", - "id": "2d77d064-41b4-41fb-82e6-2d16461269ec", - "metadata": { - "tags": [] - }, - "source": [ - "## Capture traces\n", - "\n", - "If you have been using LangChainPlus already, you may have datasets available. To view all saved datasets, run:\n", - "\n", - "```\n", - "datasets = client.list_datasets()\n", - "print(datasets)\n", - "```\n", - "\n", - "Datasets can be created in a number of ways, most often by collecting `Run`'s captured through the LangChain tracing API and converting a set of runs to a dataset.\n", - "\n", - "The V2 tracing API is currently accessible using the `tracing_v2_enabled` context manager. Assuming the server was succesfully started above, running LangChain Agents, Chains, LLMs, and other primitives will then automatically capture traces. We'll start with a simple math example.\n", - "\n", - "**Note** You can also use the `LANGCHAIN_TRACING_V2` environment variable to enable tracing for all runs by default, regardless of whether or not those runs happen within the `tracing_v2_enabled` context manager (i.e. `os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"`)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4417e0b8-a26f-4a11-b7eb-ba7a18e73885", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.callbacks.manager import tracing_v2_enabled" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "7c801853-8e96-404d-984c-51ace59cbbef", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.chat_models import ChatOpenAI\n", - "from langchain.agents import initialize_agent, load_tools\n", - "from langchain.agents import AgentType\n", - "\n", - "llm = ChatOpenAI(temperature=0)\n", - "tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n", - "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "19537902-b95c-4390-80a4-f6c9a937081e", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The current population of Canada as of 2023 is 39,566,248.\n", - "Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\n", - "LLMMathChain._evaluate(\"\n", - "(age)**0.43\n", - "\") raised error: 'age'. Please try again with a valid numerical expression\n", - "The distance between Paris and Boston is approximately 3448 miles.\n", - "unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n", - "LLMMathChain._evaluate(\"\n", - "(total number of points scored in the 2023 super bowl)**0.23\n", - "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n", - "Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n", - "1.9347796717823205\n", - "77\n", - "0.2791714614499425\n" - ] - } - ], - "source": [ - "inputs = [\n", - "'How many people live in canada as of 2023?',\n", - " \"who is dua lipa's boyfriend? what is his age raised to the .43 power?\",\n", - " \"what is dua lipa's boyfriend age raised to the .43 power?\",\n", - " 'how far is it from paris to boston in miles',\n", - " 'what was the total number of points scored in the 2023 super bowl? what is that number raised to the .23 power?',\n", - " 'what was the total number of points scored in the 2023 super bowl raised to the .23 power?',\n", - " 'how many more points were scored in the 2023 super bowl than in the 2022 super bowl?',\n", - " 'what is 153 raised to .1312 power?',\n", - " \"who is kendall jenner's boyfriend? what is his height (in inches) raised to .13 power?\",\n", - " 'what is 1213 divided by 4345?'\n", - "]\n", - "with tracing_v2_enabled(session_name=\"search_and_math_chain\"):\n", - " for input_example in inputs:\n", - " try:\n", - " print(agent.run(input_example))\n", - " except Exception as e:\n", - " # The agent sometimes makes mistakes! These will be captured by the tracing.\n", - " print(e)\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "6c43c311-4e09-4d57-9ef3-13afb96ff430", - "metadata": {}, - "source": [ - "## Creating the Dataset\n", - "\n", - "Now that you've captured a session entitled 'search_and_math_chain', it's time to create a dataset:\n", - "\n", - " 1. Navigate to the UI by clicking on the link below.\n", - " 2. Select the 'search_and_math_chain' session from the list.\n", - " 3. Next to the fist example, click \"+ to Dataset\".\n", - " 4. Click \"Create Dataset\" and create a title **\"calculator-example-dataset\"**.\n", - " 5. Add the other examples to the dataset as well" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "d14a9881-2a01-404c-8c56-0b78565c3ff4", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "dataset_name = \"calculator-example-dataset\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "7bfb4699-62c3-400a-b3e7-7fb8ad3a68ad", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "LangChain+ Client" - ], - "text/plain": [ - "LangChainPlusClient (API URL: http://localhost:8000)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client" - ] - }, - { - "cell_type": "markdown", - "id": "db79dea2-fbaa-4c12-9083-f6154b51e2d3", - "metadata": { - "jp-MarkdownHeadingCollapsed": true, - "tags": [] - }, - "source": [ - "**Optional:** If you didn't run the trace above, you can also create datasets by uploading dataframes or CSV files." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "1baa677c-5642-4378-8e01-3aa1647f19d6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# !pip install datasets > /dev/null\n", - "# !pip install pandas > /dev/null" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "60d14593-c61f-449f-a38f-772ca43707c2", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c34edde8de5340888b3278d1ac427417", - "version_major": 2, - "version_minor": 0 + "cells": [{ + "cell_type": "markdown", + "id": "1a4596ea-a631-416d-a2a4-3577c140493d", + "metadata": {}, + "source": [ + "# Running Chains on Traced Datasets\n", + "\n", + "Developing applications with language models can be uniquely challenging. To manage this complexity and ensure reliable performance, LangChain provides tracing and evaluation functionality. This notebook demonstrates how to run Chains, which are language model functions, as well as Chat models, and LLMs on previously captured datasets or traces. Some common use cases for this approach include:\n", + "\n", + "- Running an evaluation chain to grade previous runs.\n", + "- Comparing different chains, LLMs, and agents on traced datasets.\n", + "- Executing a stochastic chain multiple times over a dataset to generate metrics before deployment.\n", + "\n", + "Please note that this notebook assumes you have LangChain+ tracing running in the background. It is also configured to work only with the V2 endpoints. To set it up, follow the [tracing directions here](..\\/..\\/tracing\\/local_installation.md).\n", + " \n", + "We'll start by creating a client to connect to LangChain+." + ] }, - "text/plain": [ - " 0%| | 0/1 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
inputoutput
0How many people live in canada as of 2023?approximately 38,625,801
1who is dua lipa's boyfriend? what is his age r...her boyfriend is Romain Gravas. his age raised...
2what is dua lipa's boyfriend age raised to the...her boyfriend is Romain Gravas. his age raised...
3how far is it from paris to boston in milesapproximately 3,435 mi
4what was the total number of points scored in ...approximately 2.682651500990882
\n", - "" - ], - "text/plain": [ - " input \\\n", - "0 How many people live in canada as of 2023? \n", - "1 who is dua lipa's boyfriend? what is his age r... \n", - "2 what is dua lipa's boyfriend age raised to the... \n", - "3 how far is it from paris to boston in miles \n", - "4 what was the total number of points scored in ... \n", - "\n", - " output \n", - "0 approximately 38,625,801 \n", - "1 her boyfriend is Romain Gravas. his age raised... \n", - "2 her boyfriend is Romain Gravas. his age raised... \n", - "3 approximately 3,435 mi \n", - "4 approximately 2.682651500990882 " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# import pandas as pd\n", - "# from langchain.evaluation.loading import load_dataset\n", - "\n", - "# dataset = load_dataset(\"agent-search-calculator\")\n", - "# df = pd.DataFrame(dataset, columns=[\"question\", \"answer\"])\n", - "# df.columns = [\"input\", \"output\"] # The chain we want to evaluate below expects inputs with the \"input\" key \n", - "# df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "52a7ea76-79ca-4765-abf7-231e884040d6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# dataset_name = \"calculator-example-dataset\"\n", - "\n", - "# if dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n", - "# dataset = client.upload_dataframe(df, \n", - "# name=dataset_name,\n", - "# description=\"A calculator example dataset\",\n", - "# input_keys=[\"input\"],\n", - "# output_keys=[\"output\"],\n", - "# )" - ] - }, - { - "cell_type": "markdown", - "id": "07885b10", - "metadata": { - "tags": [] - }, - "source": [ - "## Running a Chain on a Traced Dataset\n", - "\n", - "Once you have a dataset, you can run a compatible chain or other object over it to see its results. The run traces will automatically be associated with the dataset for easy attribution and analysis.\n", - "\n", - "**First, we'll define the chain we wish to run over the dataset.**\n", - "\n", - "In this case, we're using an agent, but it can be any simple chain." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c2b59104-b90e-466a-b7ea-c5bd0194263b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.chat_models import ChatOpenAI\n", - "from langchain.agents import initialize_agent, load_tools\n", - "from langchain.agents import AgentType\n", - "\n", - "llm = ChatOpenAI(temperature=0)\n", - "tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n", - "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)" - ] - }, - { - "cell_type": "markdown", - "id": "84094a4a-1d76-461c-bc37-8c537939b466", - "metadata": {}, - "source": [ - "**Now we're ready to run the chain!**\n", - "\n", - "The docstring below hints ways you can configure the method to run." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u001b[0;31mSignature:\u001b[0m\n", - "\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mconcurrency_level\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mDocstring:\u001b[0m\n", - "Run the chain on a dataset and store traces to the specified session name.\n", - "\n", - "Args:\n", - " dataset_name: Name of the dataset to run the chain on.\n", - " llm_or_chain_factory: Language model or Chain constructor to run\n", - " over the dataset. The Chain constructor is used to permit\n", - " independent calls on each example without carrying over state.\n", - " concurrency_level: The number of async tasks to run concurrently.\n", - " num_repetitions: Number of times to run the model on each example.\n", - " This is useful when testing success rates or generating confidence\n", - " intervals.\n", - " session_name: Name of the session to store the traces in.\n", - " Defaults to {dataset_name}-{chain class name}-{datetime}.\n", - " verbose: Whether to print progress.\n", - "\n", - "Returns:\n", - " A dictionary mapping example ids to the model outputs.\n", - "\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/langchain.py\n", - "\u001b[0;31mType:\u001b[0m method" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "?client.arun_on_dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "6e10f823", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Since chains can be stateful (e.g. they can have memory), we need provide\n", - "# a way to initialize a new chain for each row in the dataset. This is done\n", - "# by passing in a factory function that returns a new chain for each row.\n", - "chain_factory = lambda: initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n", - "\n", - "# If your chain is NOT stateful, your lambda can return the object directly\n", - "# to improve runtime performance. For example:\n", - "# chain_factory = lambda: agent" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", - " warnings.warn(\n", - "Chain failed for example 5523e460-6bb4-4a64-be37-bec0a98699a4. Error: LLMMathChain._evaluate(\"\n", - "(total number of points scored in the 2023 super bowl)**0.23\n", - "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processed examples: 2\r" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Chain failed for example f193a3f6-1147-4ce6-a83e-fab1157dc88d. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processed examples: 6\r" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Chain failed for example 6d7bbb45-1dc0-4adc-be21-4f76a208a8d2. Error: LLMMathChain._evaluate(\"\n", - "(age ** 0.43)\n", - "\") raised error: 'age'. Please try again with a valid numerical expression\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processed examples: 10\r" - ] - } - ], - "source": [ - "chain_results = await client.arun_on_dataset(\n", - " dataset_name=dataset_name,\n", - " llm_or_chain_factory=chain_factory,\n", - " verbose=True\n", - ")\n", - "\n", - "# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n", - "# These are logged as warnings here and captured as errors in the tracing UI." - ] - }, - { - "cell_type": "markdown", - "id": "d2737458-b20c-4288-8790-1f4a8d237b2a", - "metadata": {}, - "source": [ - "## Reviewing the Chain Results\n", - "\n", - "You can review the results of the run in the tracing UI below and navigating to the session \n", - "with the title 'calculator-example-dataset-AgentExecutor-YYYY-MM-DD-HH-MM-SS'" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "136db492-d6ca-4215-96f9-439c23538241", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "LangChain+ Client" - ], - "text/plain": [ - "LangChainPlusClient (API URL: http://localhost:8000)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# You can navigate to the UI by clicking on the link below\n", - "client" - ] - }, - { - "cell_type": "markdown", - "id": "c70cceb5-aa53-4851-bb12-386f092191f9", - "metadata": {}, - "source": [ - "### Running a Chat Model over a Traced Dataset\n", - "\n", - "We've shown how to run a _chain_ over a dataset, but you can also run an LLM or Chat model over a datasets formed from runs. \n", - "\n", - "First, we'll show an example using a ChatModel. This is useful for things like:\n", - "- Comparing results under different decoding parameters\n", - "- Comparing model providers\n", - "- Testing for regressions in model behavior\n", - "- Running multiple times with a temperature to gauge stability \n", - "\n", - "To speed things up, we'll upload a dataset we've previously captured directly to the tracing service." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "64490d7c-9a18-49ed-a3ac-36049c522cb4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--two-player-dnd-cc62c3037e2d9250/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "047a8094367f43938f74e863b3e01711", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "code", + "execution_count": 18, + "id": "904db9a5-f387-4a57-914c-c8af8d39e249", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stdout", + "output_type": "stream", + "text": [ + "You can click the link below to view the UI\n" + ] + }, + { + "data": { + "text/html": [ + "LangChain+ Client" + ], + "text/plain": [ + "LangChainPlusClient (API URL: http://localhost:8000)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.client import LangChainPlusClient\n", + "\n", + "client = LangChainPlusClient(\n", + " api_url=\"http://localhost:8000\",\n", + " api_key=None,\n", + " # tenant_id=\"your_tenant_uuid\", # This is required when connecting to a hosted LangChain instance\n", + ")\n", + "print(\"You can click the link below to view the UI\")\n", + "client" + ] }, - "text/plain": [ - " 0%| | 0/1 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
generationsmessages
0[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
1[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
2[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
3[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
4[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
\n", - "" - ], - "text/plain": [ - " generations \\\n", - "0 [[{'generation_info': None, 'message': {'conte... \n", - "1 [[{'generation_info': None, 'message': {'conte... \n", - "2 [[{'generation_info': None, 'message': {'conte... \n", - "3 [[{'generation_info': None, 'message': {'conte... \n", - "4 [[{'generation_info': None, 'message': {'conte... \n", - "\n", - " messages \n", - "0 [{'data': {'content': 'Here is the topic for a... \n", - "1 [{'data': {'content': 'Here is the topic for a... \n", - "2 [{'data': {'content': 'Here is the topic for a... \n", - "3 [{'data': {'content': 'Here is the topic for a... \n", - "4 [{'data': {'content': 'Here is the topic for a... " - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import pandas as pd\n", - "from langchain.evaluation.loading import load_dataset\n", - "\n", - "chat_dataset = load_dataset(\"two-player-dnd\")\n", - "chat_df = pd.DataFrame(chat_dataset)\n", - "chat_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "348acd86-a927-4d60-8d52-02e64585e4fc", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "chat_dataset_name = \"two-player-dnd\"\n", - "\n", - "if chat_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n", - " client.upload_dataframe(chat_df, \n", - " name=chat_dataset_name,\n", - " description=\"An example dataset traced from chat models in a multiagent bidding dialogue\",\n", - " input_keys=[\"messages\"],\n", - " output_keys=[\"generations\"],\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "927a43b8-e4f9-4220-b75d-33e310bc318b", - "metadata": {}, - "source": [ - "#### Reviewing behavior with temperature\n", - "\n", - "Here, we will set `num_repetitions > 1` and set the temperature to 0.3 to see the variety of response types for a each example.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "a69dd183-ad5e-473d-b631-db90706e837f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.chat_models import ChatAnthropic\n", - "\n", - "chat_model = ChatAnthropic(temperature=.3)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "063da2a9-3692-4b7b-8edb-e474824fe416", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processed examples: 36\r" - ] - } - ], - "source": [ - "chat_model_results = await client.arun_on_dataset(\n", - " dataset_name=chat_dataset_name,\n", - " llm_or_chain_factory=chat_model,\n", - " concurrency_level=5, # Optional, sets the number of examples to run at a time\n", - " num_repetitions=3,\n", - " verbose=True\n", - ")\n", - "\n", - "# The 'experimental tracing v2' warning is expected, as we are still actively developing the v2 tracing API \n", - "# Since we are running examples concurrently, you may run into some RateLimit warnings from your model\n", - "# provider. In most cases, the tests will still run to completion (the wrappers have backoff)." - ] - }, - { - "cell_type": "markdown", - "id": "de7bfe08-215c-4328-b9b0-631d9a41f0e8", - "metadata": { - "tags": [] - }, - "source": [ - "## Reviewing the Chat Model Results\n", - "\n", - "You can review the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "5b7a81f2-d19d-438b-a4bb-5678f746b965", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "LangChain+ Client" - ], - "text/plain": [ - "LangChainPlusClient (API URL: http://localhost:8000)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client" - ] - }, - { - "cell_type": "markdown", - "id": "7896cbeb-345f-430b-ab5e-e108973174f8", - "metadata": {}, - "source": [ - "## Running an LLM over a Traced Dataset\n", - "\n", - "You can run an LLM over a dataset in much the same way as the chain and chat models, provided the dataset you've captured is in the appropriate format. We've cached one for you here, but using application-specific traces will be much more useful for your use cases." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "d6805d0b-4612-4671-bffb-e6978992bd40", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.llms import OpenAI\n", - "\n", - "llm = OpenAI(model_name='text-curie-001', temperature=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "5d7cb243-40c3-44dd-8158-a7b910441e9f", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--state-of-the-union-completions-a7eb4af13453cd35/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "189832bd50114f129fb58e590d6e8267", - "version_major": 2, - "version_minor": 0 + { + "cell_type": "markdown", + "id": "2d77d064-41b4-41fb-82e6-2d16461269ec", + "metadata": { + "tags": [] + }, + "source": [ + "## Capture traces\n", + "\n", + "If you have been using LangChainPlus already, you may have datasets available. To view all saved datasets, run:\n", + "\n", + "```\n", + "datasets = client.list_datasets()\n", + "print(datasets)\n", + "```\n", + "\n", + "Datasets can be created in a number of ways, most often by collecting `Run`'s captured through the LangChain tracing API and converting a set of runs to a dataset.\n", + "\n", + "The V2 tracing API is currently accessible using the `tracing_v2_enabled` context manager. Assuming the server was succesfully started above, running LangChain Agents, Chains, LLMs, and other primitives will then automatically capture traces. We'll start with a simple math example.\n", + "\n", + "**Note** You can also use the `LANGCHAIN_TRACING_V2` environment variable to enable tracing for all runs by default, regardless of whether or not those runs happen within the `tracing_v2_enabled` context manager (i.e. `os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"`)" + ] }, - "text/plain": [ - " 0%| | 0/1 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
generationsground_truthprompt
0[[{'generation_info': {'finish_reason': 'stop'...The pandemic has been punishing. \\n\\nAnd so ma...Putin may circle Kyiv with tanks, but he will ...
1[[]]With a duty to one another to the American peo...Madam Speaker, Madam Vice President, our First...
2[[{'generation_info': {'finish_reason': 'stop'...He thought he could roll into Ukraine and the ...With a duty to one another to the American peo...
3[[{'generation_info': {'finish_reason': 'lengt...With a duty to one another to the American peo...Madam Speaker, Madam Vice President, our First...
4[[]]And the costs and the threats to America and t...Please rise if you are able and show that, Yes...
\n", - "" - ], - "text/plain": [ - " generations \\\n", - "0 [[{'generation_info': {'finish_reason': 'stop'... \n", - "1 [[]] \n", - "2 [[{'generation_info': {'finish_reason': 'stop'... \n", - "3 [[{'generation_info': {'finish_reason': 'lengt... \n", - "4 [[]] \n", - "\n", - " ground_truth \\\n", - "0 The pandemic has been punishing. \\n\\nAnd so ma... \n", - "1 With a duty to one another to the American peo... \n", - "2 He thought he could roll into Ukraine and the ... \n", - "3 With a duty to one another to the American peo... \n", - "4 And the costs and the threats to America and t... \n", - "\n", - " prompt \n", - "0 Putin may circle Kyiv with tanks, but he will ... \n", - "1 Madam Speaker, Madam Vice President, our First... \n", - "2 With a duty to one another to the American peo... \n", - "3 Madam Speaker, Madam Vice President, our First... \n", - "4 Please rise if you are able and show that, Yes... " - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } + { + "cell_type": "code", + "execution_count": 19, + "id": "4417e0b8-a26f-4a11-b7eb-ba7a18e73885", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.callbacks.manager import tracing_v2_enabled" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7c801853-8e96-404d-984c-51ace59cbbef", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents import initialize_agent, load_tools\n", + "from langchain.agents import AgentType\n", + "\n", + "llm = ChatOpenAI(temperature=0)\n", + "tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n", + "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "19537902-b95c-4390-80a4-f6c9a937081e", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The current population of Canada as of 2023 is 39,566,248.\n", + "Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\n", + "LLMMathChain._evaluate(\"\n", + "(age)**0.43\n", + "\") raised error: 'age'. Please try again with a valid numerical expression\n", + "The distance between Paris and Boston is approximately 3448 miles.\n", + "unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n", + "LLMMathChain._evaluate(\"\n", + "(total number of points scored in the 2023 super bowl)**0.23\n", + "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n", + "Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n", + "1.9347796717823205\n", + "77\n", + "0.2791714614499425\n" + ] + } + ], + "source": [ + "inputs = [\n", + "'How many people live in canada as of 2023?',\n", + " \"who is dua lipa's boyfriend? what is his age raised to the .43 power?\",\n", + " \"what is dua lipa's boyfriend age raised to the .43 power?\",\n", + " 'how far is it from paris to boston in miles',\n", + " 'what was the total number of points scored in the 2023 super bowl? what is that number raised to the .23 power?',\n", + " 'what was the total number of points scored in the 2023 super bowl raised to the .23 power?',\n", + " 'how many more points were scored in the 2023 super bowl than in the 2022 super bowl?',\n", + " 'what is 153 raised to .1312 power?',\n", + " \"who is kendall jenner's boyfriend? what is his height (in inches) raised to .13 power?\",\n", + " 'what is 1213 divided by 4345?'\n", + "]\n", + "with tracing_v2_enabled(session_name=\"search_and_math_chain\"):\n", + " for input_example in inputs:\n", + " try:\n", + " print(agent.run(input_example))\n", + " except Exception as e:\n", + " # The agent sometimes makes mistakes! These will be captured by the tracing.\n", + " print(e)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "6c43c311-4e09-4d57-9ef3-13afb96ff430", + "metadata": {}, + "source": [ + "## Creating the Dataset\n", + "\n", + "Now that you've captured a session entitled 'search_and_math_chain', it's time to create a dataset:\n", + "\n", + " 1. Navigate to the UI by clicking on the link below.\n", + " 2. Select the 'search_and_math_chain' session from the list.\n", + " 3. Next to the fist example, click \"+ to Dataset\".\n", + " 4. Click \"Create Dataset\" and create a title **\"calculator-example-dataset\"**.\n", + " 5. Add the other examples to the dataset as well" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d14a9881-2a01-404c-8c56-0b78565c3ff4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dataset_name = \"calculator-example-dataset\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7bfb4699-62c3-400a-b3e7-7fb8ad3a68ad", + "metadata": { + "tags": [] + }, + "outputs": [{ + "data": { + "text/html": [ + "LangChain+ Client" + ], + "text/plain": [ + "LangChainPlusClient (API URL: http://localhost:8000)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }], + "source": [ + "client" + ] + }, + { + "cell_type": "markdown", + "id": "db79dea2-fbaa-4c12-9083-f6154b51e2d3", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "**Optional:** If you didn't run the trace above, you can also create datasets by uploading dataframes or CSV files." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1baa677c-5642-4378-8e01-3aa1647f19d6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# !pip install datasets > /dev/null\n", + "# !pip install pandas > /dev/null" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "60d14593-c61f-449f-a38f-772ca43707c2", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c34edde8de5340888b3278d1ac427417", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
inputoutput
0How many people live in canada as of 2023?approximately 38,625,801
1who is dua lipa's boyfriend? what is his age r...her boyfriend is Romain Gravas. his age raised...
2what is dua lipa's boyfriend age raised to the...her boyfriend is Romain Gravas. his age raised...
3how far is it from paris to boston in milesapproximately 3,435 mi
4what was the total number of points scored in ...approximately 2.682651500990882
\n", + "" + ], + "text/plain": [ + " input \\\n", + "0 How many people live in canada as of 2023? \n", + "1 who is dua lipa's boyfriend? what is his age r... \n", + "2 what is dua lipa's boyfriend age raised to the... \n", + "3 how far is it from paris to boston in miles \n", + "4 what was the total number of points scored in ... \n", + "\n", + " output \n", + "0 approximately 38,625,801 \n", + "1 her boyfriend is Romain Gravas. his age raised... \n", + "2 her boyfriend is Romain Gravas. his age raised... \n", + "3 approximately 3,435 mi \n", + "4 approximately 2.682651500990882 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# import pandas as pd\n", + "# from langchain.evaluation.loading import load_dataset\n", + "\n", + "# dataset = load_dataset(\"agent-search-calculator\")\n", + "# df = pd.DataFrame(dataset, columns=[\"question\", \"answer\"])\n", + "# df.columns = [\"input\", \"output\"] # The chain we want to evaluate below expects inputs with the \"input\" key \n", + "# df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "52a7ea76-79ca-4765-abf7-231e884040d6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# dataset_name = \"calculator-example-dataset\"\n", + "\n", + "# if dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n", + "# dataset = client.upload_dataframe(df, \n", + "# name=dataset_name,\n", + "# description=\"A calculator example dataset\",\n", + "# input_keys=[\"input\"],\n", + "# output_keys=[\"output\"],\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "id": "07885b10", + "metadata": { + "tags": [] + }, + "source": [ + "## Running a Chain on a Traced Dataset\n", + "\n", + "Once you have a dataset, you can run a compatible chain or other object over it to see its results. The run traces will automatically be associated with the dataset for easy attribution and analysis.\n", + "\n", + "**First, we'll define the chain we wish to run over the dataset.**\n", + "\n", + "In this case, we're using an agent, but it can be any simple chain." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c2b59104-b90e-466a-b7ea-c5bd0194263b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents import initialize_agent, load_tools\n", + "from langchain.agents import AgentType\n", + "\n", + "llm = ChatOpenAI(temperature=0)\n", + "tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n", + "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)" + ] + }, + { + "cell_type": "markdown", + "id": "84094a4a-1d76-461c-bc37-8c537939b466", + "metadata": {}, + "source": [ + "**Now we're ready to run the chain!**\n", + "\n", + "The docstring below hints ways you can configure the method to run." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee", + "metadata": { + "tags": [] + }, + "outputs": [{ + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m\n", + "\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mconcurrency_level\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m\n", + "Run the chain on a dataset and store traces to the specified session name.\n", + "\n", + "Args:\n", + " dataset_name: Name of the dataset to run the chain on.\n", + " llm_or_chain_factory: Language model or Chain constructor to run\n", + " over the dataset. The Chain constructor is used to permit\n", + " independent calls on each example without carrying over state.\n", + " concurrency_level: The number of async tasks to run concurrently.\n", + " num_repetitions: Number of times to run the model on each example.\n", + " This is useful when testing success rates or generating confidence\n", + " intervals.\n", + " session_name: Name of the session to store the traces in.\n", + " Defaults to {dataset_name}-{chain class name}-{datetime}.\n", + " verbose: Whether to print progress.\n", + "\n", + "Returns:\n", + " A dictionary mapping example ids to the model outputs.\n", + "\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/langchain.py\n", + "\u001b[0;31mType:\u001b[0m method" + ] + }, + "metadata": {}, + "output_type": "display_data" + }], + "source": [ + "?client.arun_on_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6e10f823", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Since chains can be stateful (e.g. they can have memory), we need provide\n", + "# a way to initialize a new chain for each row in the dataset. This is done\n", + "# by passing in a factory function that returns a new chain for each row.\n", + "chain_factory = lambda: initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n", + "\n", + "# If your chain is NOT stateful, your lambda can return the object directly\n", + "# to improve runtime performance. For example:\n", + "# chain_factory = lambda: agent" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", + " warnings.warn(\n", + "Chain failed for example 5523e460-6bb4-4a64-be37-bec0a98699a4. Error: LLMMathChain._evaluate(\"\n", + "(total number of points scored in the 2023 super bowl)**0.23\n", + "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed examples: 2\r" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Chain failed for example f193a3f6-1147-4ce6-a83e-fab1157dc88d. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed examples: 6\r" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Chain failed for example 6d7bbb45-1dc0-4adc-be21-4f76a208a8d2. Error: LLMMathChain._evaluate(\"\n", + "(age ** 0.43)\n", + "\") raised error: 'age'. Please try again with a valid numerical expression\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed examples: 10\r" + ] + } + ], + "source": [ + "chain_results = await client.arun_on_dataset(\n", + " dataset_name=dataset_name,\n", + " llm_or_chain_factory=chain_factory,\n", + " verbose=True\n", + ")\n", + "\n", + "# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n", + "# These are logged as warnings here and captured as errors in the tracing UI." + ] + }, + { + "cell_type": "markdown", + "id": "d2737458-b20c-4288-8790-1f4a8d237b2a", + "metadata": {}, + "source": [ + "## Reviewing the Chain Results\n", + "\n", + "You can review the results of the run in the tracing UI below and navigating to the session \n", + "with the title 'calculator-example-dataset-AgentExecutor-YYYY-MM-DD-HH-MM-SS'" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "136db492-d6ca-4215-96f9-439c23538241", + "metadata": { + "tags": [] + }, + "outputs": [{ + "data": { + "text/html": [ + "LangChain+ Client" + ], + "text/plain": [ + "LangChainPlusClient (API URL: http://localhost:8000)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }], + "source": [ + "# You can navigate to the UI by clicking on the link below\n", + "client" + ] + }, + { + "cell_type": "markdown", + "id": "c70cceb5-aa53-4851-bb12-386f092191f9", + "metadata": {}, + "source": [ + "### Running a Chat Model over a Traced Dataset\n", + "\n", + "We've shown how to run a _chain_ over a dataset, but you can also run an LLM or Chat model over a datasets formed from runs. \n", + "\n", + "First, we'll show an example using a ChatModel. This is useful for things like:\n", + "- Comparing results under different decoding parameters\n", + "- Comparing model providers\n", + "- Testing for regressions in model behavior\n", + "- Running multiple times with a temperature to gauge stability \n", + "\n", + "To speed things up, we'll upload a dataset we've previously captured directly to the tracing service." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "64490d7c-9a18-49ed-a3ac-36049c522cb4", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--two-player-dnd-cc62c3037e2d9250/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0adb751cec11417b88072963325b481d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
generationsmessages
0[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
1[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
2[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
3[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
4[[{'generation_info': None, 'message': {'conte...[{'data': {'content': 'Here is the topic for a...
\n", + "" + ], + "text/plain": [ + " generations \\\n", + "0 [[{'generation_info': None, 'message': {'conte... \n", + "1 [[{'generation_info': None, 'message': {'conte... \n", + "2 [[{'generation_info': None, 'message': {'conte... \n", + "3 [[{'generation_info': None, 'message': {'conte... \n", + "4 [[{'generation_info': None, 'message': {'conte... \n", + "\n", + " messages \n", + "0 [{'data': {'content': 'Here is the topic for a... \n", + "1 [{'data': {'content': 'Here is the topic for a... \n", + "2 [{'data': {'content': 'Here is the topic for a... \n", + "3 [{'data': {'content': 'Here is the topic for a... \n", + "4 [{'data': {'content': 'Here is the topic for a... " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "from langchain.evaluation.loading import load_dataset\n", + "\n", + "chat_dataset = load_dataset(\"two-player-dnd\")\n", + "chat_df = pd.DataFrame(chat_dataset)\n", + "chat_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "348acd86-a927-4d60-8d52-02e64585e4fc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chat_dataset_name = \"two-player-dnd\"\n", + "\n", + "if chat_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n", + " client.upload_dataframe(chat_df, \n", + " name=chat_dataset_name,\n", + " description=\"An example dataset traced from chat models in a multiagent bidding dialogue\",\n", + " input_keys=[\"messages\"],\n", + " output_keys=[\"generations\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "927a43b8-e4f9-4220-b75d-33e310bc318b", + "metadata": {}, + "source": [ + "#### Reviewing behavior with temperature\n", + "\n", + "Here, we will set `num_repetitions > 1` and set the temperature to 0.3 to see the variety of response types for a each example.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "a69dd183-ad5e-473d-b631-db90706e837f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatAnthropic\n", + "\n", + "chat_model = ChatAnthropic(temperature=.3)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "063da2a9-3692-4b7b-8edb-e474824fe416", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed examples: 36\r" + ] + } + ], + "source": [ + "chat_model_results = await client.arun_on_dataset(\n", + " dataset_name=chat_dataset_name,\n", + " llm_or_chain_factory=chat_model,\n", + " concurrency_level=5, # Optional, sets the number of examples to run at a time\n", + " num_repetitions=3,\n", + " verbose=True\n", + ")\n", + "\n", + "# The 'experimental tracing v2' warning is expected, as we are still actively developing the v2 tracing API \n", + "# Since we are running examples concurrently, you may run into some RateLimit warnings from your model\n", + "# provider. In most cases, the tests will still run to completion (the wrappers have backoff)." + ] + }, + { + "cell_type": "markdown", + "id": "de7bfe08-215c-4328-b9b0-631d9a41f0e8", + "metadata": { + "tags": [] + }, + "source": [ + "## Reviewing the Chat Model Results\n", + "\n", + "You can review the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5b7a81f2-d19d-438b-a4bb-5678f746b965", + "metadata": { + "tags": [] + }, + "outputs": [{ + "data": { + "text/html": [ + "LangChain+ Client" + ], + "text/plain": [ + "LangChainPlusClient (API URL: http://localhost:8000)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + }], + "source": [ + "client" + ] + }, + { + "cell_type": "markdown", + "id": "7896cbeb-345f-430b-ab5e-e108973174f8", + "metadata": {}, + "source": [ + "## Running an LLM over a Traced Dataset\n", + "\n", + "You can run an LLM over a dataset in much the same way as the chain and chat models, provided the dataset you've captured is in the appropriate format. We've cached one for you here, but using application-specific traces will be much more useful for your use cases." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d6805d0b-4612-4671-bffb-e6978992bd40", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "\n", + "llm = OpenAI(model_name='text-curie-001', temperature=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5d7cb243-40c3-44dd-8158-a7b910441e9f", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--state-of-the-union-completions-a7eb4af13453cd35/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "189832bd50114f129fb58e590d6e8267", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
generationsground_truthprompt
0[[{'generation_info': {'finish_reason': 'stop'...The pandemic has been punishing. \\n\\nAnd so ma...Putin may circle Kyiv with tanks, but he will ...
1[[]]With a duty to one another to the American peo...Madam Speaker, Madam Vice President, our First...
2[[{'generation_info': {'finish_reason': 'stop'...He thought he could roll into Ukraine and the ...With a duty to one another to the American peo...
3[[{'generation_info': {'finish_reason': 'lengt...With a duty to one another to the American peo...Madam Speaker, Madam Vice President, our First...
4[[]]And the costs and the threats to America and t...Please rise if you are able and show that, Yes...
\n", + "" + ], + "text/plain": [ + " generations \\\n", + "0 [[{'generation_info': {'finish_reason': 'stop'... \n", + "1 [[]] \n", + "2 [[{'generation_info': {'finish_reason': 'stop'... \n", + "3 [[{'generation_info': {'finish_reason': 'lengt... \n", + "4 [[]] \n", + "\n", + " ground_truth \\\n", + "0 The pandemic has been punishing. \\n\\nAnd so ma... \n", + "1 With a duty to one another to the American peo... \n", + "2 He thought he could roll into Ukraine and the ... \n", + "3 With a duty to one another to the American peo... \n", + "4 And the costs and the threats to America and t... \n", + "\n", + " prompt \n", + "0 Putin may circle Kyiv with tanks, but he will ... \n", + "1 Madam Speaker, Madam Vice President, our First... \n", + "2 With a duty to one another to the American peo... \n", + "3 Madam Speaker, Madam Vice President, our First... \n", + "4 Please rise if you are able and show that, Yes... " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "completions_dataset = load_dataset(\"state-of-the-union-completions\")\n", + "completions_df = pd.DataFrame(completions_dataset)\n", + "completions_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c7dcc1b2-7aef-44c0-ba0f-c812279099a5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "completions_dataset_name = \"state-of-the-union-completions\"\n", + "\n", + "if completions_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n", + " client.upload_dataframe(completions_df, \n", + " name=completions_dataset_name,\n", + " description=\"An example dataset traced from completion endpoints over the state of the union address\",\n", + " input_keys=[\"prompt\"],\n", + " output_keys=[\"generations\"],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e946138e-bf7c-43d7-861d-9c5740c933fa", + "metadata": { + "tags": [] + }, + "outputs": [{ + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "55 processed\r" + ] + } + ], + "source": [ + "# We also offer a synchronous method for running examples if a chain or llm's async methods aren't yet implemented\n", + "completions_model_results = client.run_on_dataset(\n", + " dataset_name=completions_dataset_name,\n", + " llm_or_chain_factory=llm,\n", + " num_repetitions=1,\n", + " verbose=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cc86e8e6-cee2-429e-942b-289284d14816", + "metadata": {}, + "source": [ + "## Reviewing the LLM Results\n", + "\n", + "You can once again inspect the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "2bf96f17-74c1-4f7d-8458-ae5ab5c6bd36", + "metadata": { + "tags": [] + }, + "outputs": [{ + "data": { + "text/html": [ + "LangChain+ Client" + ], + "text/plain": [ + "LangChainPlusClient (API URL: http://localhost:8000)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + }], + "source": [ + "client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df80cd88-cd6f-4fdc-965f-f74600e1f286", + "metadata": {}, + "outputs": [], + "source": [] + } ], - "source": [ - "completions_dataset = load_dataset(\"state-of-the-union-completions\")\n", - "completions_df = pd.DataFrame(completions_dataset)\n", - "completions_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "c7dcc1b2-7aef-44c0-ba0f-c812279099a5", "metadata": { - "tags": [] + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, - "outputs": [], - "source": [ - "completions_dataset_name = \"state-of-the-union-completions\"\n", - "\n", - "if completions_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n", - " client.upload_dataframe(completions_df, \n", - " name=completions_dataset_name,\n", - " description=\"An example dataset traced from completion endpoints over the state of the union address\",\n", - " input_keys=[\"prompt\"],\n", - " output_keys=[\"generations\"],\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "e946138e-bf7c-43d7-861d-9c5740c933fa", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "55 processed\r" - ] - } - ], - "source": [ - "# We also offer a synchronous method for running examples if a chain or llm's async methods aren't yet implemented\n", - "completions_model_results = client.run_on_dataset(\n", - " dataset_name=completions_dataset_name,\n", - " llm_or_chain_factory=llm,\n", - " num_repetitions=1,\n", - " verbose=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "cc86e8e6-cee2-429e-942b-289284d14816", - "metadata": {}, - "source": [ - "## Reviewing the LLM Results\n", - "\n", - "You can once again inspect the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session." - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "2bf96f17-74c1-4f7d-8458-ae5ab5c6bd36", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "LangChain+ Client" - ], - "text/plain": [ - "LangChainPlusClient (API URL: http://localhost:8000)" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "df80cd88-cd6f-4fdc-965f-f74600e1f286", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/tests/unit_tests/callbacks/tracers/test_base_tracer.py new file mode 100644 index 0000000000..4ff2e342c5 --- /dev/null +++ b/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -0,0 +1,464 @@ +"""Test Tracer classes.""" +from __future__ import annotations + +from datetime import datetime +from typing import List +from uuid import uuid4 + +import pytest +from freezegun import freeze_time + +from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.tracers.base import BaseTracer, TracerException +from langchain.callbacks.tracers.schemas import Run +from langchain.schema import LLMResult + + +class FakeTracer(BaseTracer): + """Fake tracer that records LangChain execution.""" + + def __init__(self) -> None: + """Initialize the tracer.""" + super().__init__() + self.runs: List[Run] = [] + + def _persist_run(self, run: Run) -> None: + """Persist a run.""" + self.runs.append(run) + + +@freeze_time("2023-01-01") +def test_tracer_llm_run() -> None: + """Test tracer on an LLM run.""" + uuid = uuid4() + compare_run = Run( + id=uuid, + parent_run_id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "llm"}, + inputs={"prompts": []}, + outputs=LLMResult(generations=[[]]), + error=None, + run_type="llm", + ) + tracer = FakeTracer() + + tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_chat_model_run() -> None: + """Test tracer on a Chat Model run.""" + uuid = uuid4() + compare_run = Run( + id=str(uuid), + name="chat_model", + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chat_model"}, + inputs=dict(prompts=[""]), + outputs=LLMResult(generations=[[]]), + error=None, + run_type="llm", + ) + tracer = FakeTracer() + manager = CallbackManager(handlers=[tracer]) + run_manager = manager.on_chat_model_start( + serialized={"name": "chat_model"}, messages=[[]], run_id=uuid + ) + run_manager.on_llm_end(response=LLMResult(generations=[[]])) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_errors_no_start() -> None: + """Test tracer on an LLM run without a start.""" + tracer = FakeTracer() + + with pytest.raises(TracerException): + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4()) + + +@freeze_time("2023-01-01") +def test_tracer_multiple_llm_runs() -> None: + """Test the tracer with multiple runs.""" + uuid = uuid4() + compare_run = Run( + id=uuid, + name="llm", + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "llm"}, + inputs=dict(prompts=[]), + outputs=LLMResult(generations=[[]]), + error=None, + run_type="llm", + ) + tracer = FakeTracer() + + num_runs = 10 + for _ in range(num_runs): + tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) + + assert tracer.runs == [compare_run] * num_runs + + +@freeze_time("2023-01-01") +def test_tracer_chain_run() -> None: + """Test tracer on a Chain run.""" + uuid = uuid4() + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chain"}, + inputs={}, + outputs={}, + error=None, + run_type="chain", + ) + tracer = FakeTracer() + + tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) + tracer.on_chain_end(outputs={}, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run() -> None: + """Test tracer on a Tool run.""" + uuid = uuid4() + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "tool"}, + inputs={"input": "test"}, + outputs={"output": "test"}, + error=None, + run_type="tool", + ) + tracer = FakeTracer() + tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) + tracer.on_tool_end("test", run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_run() -> None: + """Test tracer on a nested run.""" + tracer = FakeTracer() + + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + for _ in range(10): + tracer.on_chain_start( + serialized={"name": "chain"}, inputs={}, run_id=chain_uuid + ) + tracer.on_tool_start( + serialized={"name": "tool"}, + input_str="test", + run_id=tool_uuid, + parent_run_id=chain_uuid, + ) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid1, + parent_run_id=tool_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_tool_end("test", run_id=tool_uuid) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid2, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_chain_end(outputs={}, run_id=chain_uuid) + + compare_run = Run( + id=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=4, + serialized={"name": "chain"}, + inputs={}, + outputs={}, + run_type="chain", + child_runs=[ + Run( + id=tool_uuid, + parent_run_id=chain_uuid, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + child_execution_order=3, + serialized={"name": "tool"}, + inputs=dict(input="test"), + outputs=dict(output="test"), + error=None, + run_type="tool", + child_runs=[ + Run( + id=str(llm_uuid1), + parent_run_id=str(tool_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + child_execution_order=3, + serialized={"name": "llm"}, + inputs=dict(prompts=[]), + outputs=LLMResult(generations=[[]]), + run_type="llm", + ) + ], + ), + Run( + id=str(llm_uuid2), + parent_run_id=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + child_execution_order=4, + serialized={"name": "llm"}, + inputs=dict(prompts=[]), + outputs=LLMResult(generations=[[]]), + run_type="llm", + ), + ], + ) + assert tracer.runs[0] == compare_run + assert tracer.runs == [compare_run] * 10 + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error() -> None: + """Test tracer on an LLM run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "llm"}, + inputs=dict(prompts=[]), + outputs=None, + error=repr(exception), + run_type="llm", + ) + tracer = FakeTracer() + + tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_chain_run_on_error() -> None: + """Test tracer on a Chain run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chain"}, + inputs={}, + outputs=None, + error=repr(exception), + run_type="chain", + ) + tracer = FakeTracer() + + tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) + tracer.on_chain_error(exception, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run_on_error() -> None: + """Test tracer on a Tool run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "tool"}, + inputs=dict(input="test"), + outputs=None, + action="{'name': 'tool'}", + error=repr(exception), + run_type="tool", + ) + tracer = FakeTracer() + + tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) + tracer.on_tool_error(exception, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_runs_on_error() -> None: + """Test tracer on a nested run with an error.""" + exception = Exception("test") + + tracer = FakeTracer() + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + llm_uuid3 = uuid4() + + for _ in range(3): + tracer.on_chain_start( + serialized={"name": "chain"}, inputs={}, run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid1, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid2, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_tool_start( + serialized={"name": "tool"}, + input_str="test", + run_id=tool_uuid, + parent_run_id=chain_uuid, + ) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid3, + parent_run_id=tool_uuid, + ) + tracer.on_llm_error(exception, run_id=llm_uuid3) + tracer.on_tool_error(exception, run_id=tool_uuid) + tracer.on_chain_error(exception, run_id=chain_uuid) + + compare_run = Run( + id=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=5, + serialized={"name": "chain"}, + error=repr(exception), + inputs={}, + outputs=None, + run_type="chain", + child_runs=[ + Run( + id=str(llm_uuid1), + parent_run_id=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + child_execution_order=2, + serialized={"name": "llm"}, + error=None, + inputs=dict(prompts=[]), + outputs=LLMResult(generations=[[]], llm_output=None), + run_type="llm", + ), + Run( + id=str(llm_uuid2), + parent_run_id=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + child_execution_order=3, + serialized={"name": "llm"}, + error=None, + inputs=dict(prompts=[]), + outputs=LLMResult(generations=[[]], llm_output=None), + run_type="llm", + ), + Run( + id=str(tool_uuid), + parent_run_id=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + child_execution_order=5, + serialized={"name": "tool"}, + error=repr(exception), + inputs=dict(input="test"), + outputs=None, + action="{'name': 'tool'}", + child_runs=[ + Run( + id=str(llm_uuid3), + parent_run_id=str(tool_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=5, + child_execution_order=5, + serialized={"name": "llm"}, + error=repr(exception), + inputs=dict(prompts=[]), + outputs=None, + run_type="llm", + ) + ], + run_type="tool", + ), + ], + ) + assert tracer.runs == [compare_run] * 3 diff --git a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py new file mode 100644 index 0000000000..ab655ac631 --- /dev/null +++ b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py @@ -0,0 +1,676 @@ +"""Test Tracer classes.""" +from __future__ import annotations + +from datetime import datetime +from typing import List, Optional, Union +from uuid import uuid4 + +import pytest +from freezegun import freeze_time + +from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.tracers.base import BaseTracer, TracerException +from langchain.callbacks.tracers.langchain_v1 import ( + ChainRun, + LangChainTracerV1, + LLMRun, + ToolRun, + TracerSessionV1, +) +from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base +from langchain.schema import LLMResult + +TEST_SESSION_ID = 2023 + + +def load_session(session_name: str) -> TracerSessionV1: + """Load a tracing session.""" + return TracerSessionV1( + id=TEST_SESSION_ID, name=session_name, start_time=datetime.utcnow() + ) + + +def new_session(name: Optional[str] = None) -> TracerSessionV1: + """Create a new tracing session.""" + return TracerSessionV1( + id=TEST_SESSION_ID, name=name or "default", start_time=datetime.utcnow() + ) + + +def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1: + """Persist a tracing session.""" + return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID}) + + +def load_default_session() -> TracerSessionV1: + """Load a tracing session.""" + return TracerSessionV1( + id=TEST_SESSION_ID, name="default", start_time=datetime.utcnow() + ) + + +@pytest.fixture +def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1: + monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") + monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") + monkeypatch.setenv("LANGCHAIN_API_KEY", "foo") + tracer = LangChainTracerV1() + return tracer + + +class FakeTracer(BaseTracer): + """Fake tracer that records LangChain execution.""" + + def __init__(self) -> None: + """Initialize the tracer.""" + super().__init__() + self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] + + def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, Run): + with pytest.MonkeyPatch().context() as m: + m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") + m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") + m.setenv("LANGCHAIN_API_KEY", "foo") + tracer = LangChainTracerV1() + tracer.load_default_session = load_default_session # type: ignore + run = tracer._convert_to_v1_run(run) + self.runs.append(run) + + def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1: + """Persist a tracing session.""" + return _persist_session(session) + + def new_session(self, name: Optional[str] = None) -> TracerSessionV1: + """Create a new tracing session.""" + return new_session(name) + + def load_session(self, session_name: str) -> TracerSessionV1: + """Load a tracing session.""" + return load_session(session_name) + + def load_default_session(self) -> TracerSessionV1: + """Load a tracing session.""" + return load_default_session() + + +@freeze_time("2023-01-01") +def test_tracer_llm_run() -> None: + """Test tracer on an LLM run.""" + uuid = uuid4() + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "llm"}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_chat_model_run() -> None: + """Test tracer on a Chat Model run.""" + uuid = uuid4() + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chat_model"}, + prompts=[""], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + manager = CallbackManager(handlers=[tracer]) + run_manager = manager.on_chat_model_start( + serialized={"name": "chat_model"}, messages=[[]], run_id=uuid + ) + run_manager.on_llm_end(response=LLMResult(generations=[[]])) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_errors_no_start() -> None: + """Test tracer on an LLM run without a start.""" + tracer = FakeTracer() + + tracer.new_session() + with pytest.raises(TracerException): + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4()) + + +@freeze_time("2023-01-01") +def test_tracer_multiple_llm_runs() -> None: + """Test the tracer with multiple runs.""" + uuid = uuid4() + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "llm"}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + num_runs = 10 + for _ in range(num_runs): + tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) + + assert tracer.runs == [compare_run] * num_runs + + +@freeze_time("2023-01-01") +def test_tracer_chain_run() -> None: + """Test tracer on a Chain run.""" + uuid = uuid4() + compare_run = ChainRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chain"}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) + tracer.on_chain_end(outputs={}, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run() -> None: + """Test tracer on a Tool run.""" + uuid = uuid4() + compare_run = ToolRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "tool"}, + tool_input="test", + output="test", + action="{'name': 'tool'}", + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) + tracer.on_tool_end("test", run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_run() -> None: + """Test tracer on a nested run.""" + tracer = FakeTracer() + tracer.new_session() + + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + for _ in range(10): + tracer.on_chain_start( + serialized={"name": "chain"}, inputs={}, run_id=chain_uuid + ) + tracer.on_tool_start( + serialized={"name": "tool"}, + input_str="test", + run_id=tool_uuid, + parent_run_id=chain_uuid, + ) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid1, + parent_run_id=tool_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_tool_end("test", run_id=tool_uuid) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid2, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_chain_end(outputs={}, run_id=chain_uuid) + + compare_run = ChainRun( + uuid=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=4, + serialized={"name": "chain"}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + child_chain_runs=[], + child_tool_runs=[ + ToolRun( + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + child_execution_order=3, + serialized={"name": "tool"}, + tool_input="test", + output="test", + action="{'name': 'tool'}", + session_id=TEST_SESSION_ID, + error=None, + child_chain_runs=[], + child_tool_runs=[], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid1), + parent_uuid=str(tool_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + child_execution_order=3, + serialized={"name": "llm"}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ) + ], + ), + ], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + child_execution_order=4, + serialized={"name": "llm"}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ), + ], + ) + assert tracer.runs[0] == compare_run + assert tracer.runs == [compare_run] * 10 + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error() -> None: + """Test tracer on an LLM run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "llm"}, + prompts=[], + response=None, + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_chain_run_on_error() -> None: + """Test tracer on a Chain run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = ChainRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "chain"}, + inputs={}, + outputs=None, + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) + tracer.on_chain_error(exception, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run_on_error() -> None: + """Test tracer on a Tool run with an error.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = ToolRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={"name": "tool"}, + tool_input="test", + output=None, + action="{'name': 'tool'}", + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) + tracer.on_tool_error(exception, run_id=uuid) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_runs_on_error() -> None: + """Test tracer on a nested run with an error.""" + exception = Exception("test") + + tracer = FakeTracer() + tracer.new_session() + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + llm_uuid3 = uuid4() + + for _ in range(3): + tracer.on_chain_start( + serialized={"name": "chain"}, inputs={}, run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid1, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid2, + parent_run_id=chain_uuid, + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_tool_start( + serialized={"name": "tool"}, + input_str="test", + run_id=tool_uuid, + parent_run_id=chain_uuid, + ) + tracer.on_llm_start( + serialized={"name": "llm"}, + prompts=[], + run_id=llm_uuid3, + parent_run_id=tool_uuid, + ) + tracer.on_llm_error(exception, run_id=llm_uuid3) + tracer.on_tool_error(exception, run_id=tool_uuid) + tracer.on_chain_error(exception, run_id=chain_uuid) + + compare_run = ChainRun( + uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=5, + serialized={"name": "chain"}, + session_id=TEST_SESSION_ID, + error=repr(exception), + inputs={}, + outputs=None, + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid1), + parent_uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + child_execution_order=2, + serialized={"name": "llm"}, + session_id=TEST_SESSION_ID, + error=None, + prompts=[], + response=LLMResult(generations=[[]], llm_output=None), + ), + LLMRun( + uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + child_execution_order=3, + serialized={"name": "llm"}, + session_id=TEST_SESSION_ID, + error=None, + prompts=[], + response=LLMResult(generations=[[]], llm_output=None), + ), + ], + child_chain_runs=[], + child_tool_runs=[ + ToolRun( + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + child_execution_order=5, + serialized={"name": "tool"}, + session_id=TEST_SESSION_ID, + error=repr(exception), + tool_input="test", + output=None, + action="{'name': 'tool'}", + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid3), + parent_uuid=str(tool_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=5, + child_execution_order=5, + serialized={"name": "llm"}, + session_id=TEST_SESSION_ID, + error=repr(exception), + prompts=[], + response=None, + ) + ], + child_chain_runs=[], + child_tool_runs=[], + ), + ], + ) + assert tracer.runs == [compare_run] * 3 + + +@pytest.fixture +def sample_tracer_session_v1() -> TracerSessionV1: + return TracerSessionV1(id=2, name="Sample session") + + +@freeze_time("2023-01-01") +def test_convert_run( + lang_chain_tracer_v1: LangChainTracerV1, + sample_tracer_session_v1: TracerSessionV1, +) -> None: + """Test converting a run to a V1 run.""" + llm_run = Run( + id="57a08cc4-73d2-4236-8370-549099d07fad", + name="llm_run", + execution_order=1, + child_execution_order=1, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + session_id=TEST_SESSION_ID, + inputs={"prompts": []}, + outputs=LLMResult(generations=[[]]).dict(), + serialized={}, + extra={}, + run_type=RunTypeEnum.llm, + ) + chain_run = Run( + id="57a08cc4-73d2-4236-8371-549099d07fad", + name="chain_run", + execution_order=1, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + child_execution_order=1, + serialized={}, + inputs={}, + outputs={}, + child_runs=[llm_run], + extra={}, + run_type=RunTypeEnum.chain, + ) + + tool_run = Run( + id="57a08cc4-73d2-4236-8372-549099d07fad", + name="tool_run", + execution_order=1, + child_execution_order=1, + inputs={"input": "test"}, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + outputs=None, + serialized={}, + child_runs=[], + extra={}, + run_type=RunTypeEnum.tool, + ) + + expected_llm_run = LLMRun( + uuid="57a08cc4-73d2-4236-8370-549099d07fad", + name="llm_run", + execution_order=1, + child_execution_order=1, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + session_id=2, + prompts=[], + response=LLMResult(generations=[[]]), + serialized={}, + extra={}, + ) + + expected_chain_run = ChainRun( + uuid="57a08cc4-73d2-4236-8371-549099d07fad", + name="chain_run", + execution_order=1, + child_execution_order=1, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + session_id=2, + serialized={}, + inputs={}, + outputs={}, + child_llm_runs=[expected_llm_run], + child_chain_runs=[], + child_tool_runs=[], + extra={}, + ) + expected_tool_run = ToolRun( + uuid="57a08cc4-73d2-4236-8372-549099d07fad", + name="tool_run", + execution_order=1, + child_execution_order=1, + session_id=2, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + tool_input="test", + action="{}", + serialized={}, + child_llm_runs=[], + child_chain_runs=[], + child_tool_runs=[], + extra={}, + ) + lang_chain_tracer_v1.session = sample_tracer_session_v1 + converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run) + converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run) + converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run) + + assert isinstance(converted_llm_run, LLMRun) + assert isinstance(converted_chain_run, ChainRun) + assert isinstance(converted_tool_run, ToolRun) + assert converted_llm_run == expected_llm_run + assert converted_tool_run == expected_tool_run + assert converted_chain_run == expected_chain_run diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 5c0c4b1138..055ac64067 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -3,621 +3,90 @@ from __future__ import annotations import json from datetime import datetime -from typing import List, Tuple, Union -from unittest.mock import Mock, patch +from typing import Tuple +from unittest.mock import patch from uuid import UUID, uuid4 import pytest from freezegun import freeze_time -from langchain.callbacks.manager import CallbackManager -from langchain.callbacks.tracers.base import ( - BaseTracer, - ChainRun, - LLMRun, - ToolRun, - TracerException, - TracerSession, -) -from langchain.callbacks.tracers.langchain import LangChainTracerV2 -from langchain.callbacks.tracers.schemas import ( - RunCreate, - TracerSessionBase, - TracerSessionV2, - TracerSessionV2Create, -) +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession from langchain.schema import LLMResult -TEST_SESSION_ID = 2023 - - -def load_session(session_name: str) -> TracerSession: - """Load a tracing session.""" - return TracerSession(id=1, name=session_name, start_time=datetime.utcnow()) - - -def _persist_session(session: TracerSessionBase) -> TracerSession: - """Persist a tracing session.""" - return TracerSession(id=TEST_SESSION_ID, **session.dict()) - - -def load_default_session() -> TracerSession: - """Load a tracing session.""" - return TracerSession(id=1, name="default", start_time=datetime.utcnow()) - - -class FakeTracer(BaseTracer): - """Fake tracer that records LangChain execution.""" - - def __init__(self) -> None: - """Initialize the tracer.""" - super().__init__() - self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] - - def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: - """Persist a run.""" - self.runs.append(run) - - def _persist_session(self, session: TracerSessionBase) -> TracerSession: - """Persist a tracing session.""" - return _persist_session(session) - - def load_session(self, session_name: str) -> TracerSession: - """Load a tracing session.""" - return load_session(session_name) - - def load_default_session(self) -> TracerSession: - """Load a tracing session.""" - return load_default_session() - - -@freeze_time("2023-01-01") -def test_tracer_llm_run() -> None: - """Test tracer on an LLM run.""" - uuid = uuid4() - compare_run = LLMRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - error=None, - ) - tracer = FakeTracer() - - tracer.new_session() - tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_chat_model_run() -> None: - """Test tracer on a Chat Model run.""" - uuid = uuid4() - compare_run = LLMRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - prompts=[""], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - error=None, - ) - tracer = FakeTracer() - - tracer.new_session() - manager = CallbackManager(handlers=[tracer]) - run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid) - run_manager.on_llm_end(response=LLMResult(generations=[[]])) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_llm_run_errors_no_start() -> None: - """Test tracer on an LLM run without a start.""" - tracer = FakeTracer() - - tracer.new_session() - with pytest.raises(TracerException): - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4()) - - -@freeze_time("2023-01-01") -def test_tracer_multiple_llm_runs() -> None: - """Test the tracer with multiple runs.""" - uuid = uuid4() - compare_run = LLMRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - error=None, - ) - tracer = FakeTracer() - - tracer.new_session() - num_runs = 10 - for _ in range(num_runs): - tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) - - assert tracer.runs == [compare_run] * num_runs - - -@freeze_time("2023-01-01") -def test_tracer_chain_run() -> None: - """Test tracer on a Chain run.""" - uuid = uuid4() - compare_run = ChainRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - inputs={}, - outputs={}, - session_id=TEST_SESSION_ID, - error=None, - ) - tracer = FakeTracer() - - tracer.new_session() - tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid) - tracer.on_chain_end(outputs={}, run_id=uuid) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_tool_run() -> None: - """Test tracer on a Tool run.""" - uuid = uuid4() - compare_run = ToolRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - tool_input="test", - output="test", - action="{}", - session_id=TEST_SESSION_ID, - error=None, - ) - tracer = FakeTracer() - - tracer.new_session() - tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid) - tracer.on_tool_end("test", run_id=uuid) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_nested_run() -> None: - """Test tracer on a nested run.""" - tracer = FakeTracer() - tracer.new_session() - - chain_uuid = uuid4() - tool_uuid = uuid4() - llm_uuid1 = uuid4() - llm_uuid2 = uuid4() - for _ in range(10): - tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) - tracer.on_tool_start( - serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid - ) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid - ) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) - tracer.on_tool_end("test", run_id=tool_uuid) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid - ) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) - tracer.on_chain_end(outputs={}, run_id=chain_uuid) - - compare_run = ChainRun( - uuid=str(chain_uuid), - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=4, - serialized={}, - inputs={}, - outputs={}, - session_id=TEST_SESSION_ID, - child_chain_runs=[], - child_tool_runs=[ - ToolRun( - uuid=str(tool_uuid), - parent_uuid=str(chain_uuid), - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=2, - child_execution_order=3, - serialized={}, - tool_input="test", - output="test", - action="{}", - session_id=TEST_SESSION_ID, - error=None, - child_chain_runs=[], - child_tool_runs=[], - child_llm_runs=[ - LLMRun( - uuid=str(llm_uuid1), - parent_uuid=str(tool_uuid), - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=3, - child_execution_order=3, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - ) - ], - ), - ], - child_llm_runs=[ - LLMRun( - uuid=str(llm_uuid2), - parent_uuid=str(chain_uuid), - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=4, - child_execution_order=4, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - ), - ], - ) - assert tracer.runs == [compare_run] * 10 - - -@freeze_time("2023-01-01") -def test_tracer_llm_run_on_error() -> None: - """Test tracer on an LLM run with an error.""" - exception = Exception("test") - uuid = uuid4() - - compare_run = LLMRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - prompts=[], - response=None, - session_id=TEST_SESSION_ID, - error=repr(exception), - ) - tracer = FakeTracer() - - tracer.new_session() - tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) - tracer.on_llm_error(exception, run_id=uuid) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_chain_run_on_error() -> None: - """Test tracer on a Chain run with an error.""" - exception = Exception("test") - uuid = uuid4() - - compare_run = ChainRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - inputs={}, - outputs=None, - session_id=TEST_SESSION_ID, - error=repr(exception), - ) - tracer = FakeTracer() - - tracer.new_session() - tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid) - tracer.on_chain_error(exception, run_id=uuid) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_tool_run_on_error() -> None: - """Test tracer on a Tool run with an error.""" - exception = Exception("test") - uuid = uuid4() - - compare_run = ToolRun( - uuid=str(uuid), - parent_uuid=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=1, - serialized={}, - tool_input="test", - output=None, - action="{}", - session_id=TEST_SESSION_ID, - error=repr(exception), - ) - tracer = FakeTracer() - - tracer.new_session() - tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid) - tracer.on_tool_error(exception, run_id=uuid) - assert tracer.runs == [compare_run] - - -@freeze_time("2023-01-01") -def test_tracer_nested_runs_on_error() -> None: - """Test tracer on a nested run with an error.""" - exception = Exception("test") - - tracer = FakeTracer() - tracer.new_session() - chain_uuid = uuid4() - tool_uuid = uuid4() - llm_uuid1 = uuid4() - llm_uuid2 = uuid4() - llm_uuid3 = uuid4() - - for _ in range(3): - tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid - ) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid - ) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) - tracer.on_tool_start( - serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid - ) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid - ) - tracer.on_llm_error(exception, run_id=llm_uuid3) - tracer.on_tool_error(exception, run_id=tool_uuid) - tracer.on_chain_error(exception, run_id=chain_uuid) - - compare_run = ChainRun( - uuid=str(chain_uuid), - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=5, - serialized={}, - session_id=TEST_SESSION_ID, - error=repr(exception), - inputs={}, - outputs=None, - child_llm_runs=[ - LLMRun( - uuid=str(llm_uuid1), - parent_uuid=str(chain_uuid), - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=2, - child_execution_order=2, - serialized={}, - session_id=TEST_SESSION_ID, - error=None, - prompts=[], - response=LLMResult(generations=[[]], llm_output=None), - ), - LLMRun( - uuid=str(llm_uuid2), - parent_uuid=str(chain_uuid), - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=3, - child_execution_order=3, - serialized={}, - session_id=TEST_SESSION_ID, - error=None, - prompts=[], - response=LLMResult(generations=[[]], llm_output=None), - ), - ], - child_chain_runs=[], - child_tool_runs=[ - ToolRun( - uuid=str(tool_uuid), - parent_uuid=str(chain_uuid), - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=4, - child_execution_order=5, - serialized={}, - session_id=TEST_SESSION_ID, - error=repr(exception), - tool_input="test", - output=None, - action="{}", - child_llm_runs=[ - LLMRun( - uuid=str(llm_uuid3), - parent_uuid=str(tool_uuid), - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=5, - child_execution_order=5, - serialized={}, - session_id=TEST_SESSION_ID, - error=repr(exception), - prompts=[], - response=None, - ) - ], - child_chain_runs=[], - child_tool_runs=[], - ), - ], - ) - - assert tracer.runs == [compare_run] * 3 - - _SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a") _TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad") @pytest.fixture -def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2: +def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer: monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") monkeypatch.setenv("LANGCHAIN_API_KEY", "foo") - tracer = LangChainTracerV2() + tracer = LangChainTracer() return tracer -# Mock a sample TracerSessionV2 object +# Mock a sample TracerSession object @pytest.fixture -def sample_tracer_session_v2() -> TracerSessionV2: - return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID) +def sample_tracer_session_v2() -> TracerSession: + return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID) -# Mock a sample LLMRun, ChainRun, and ToolRun objects +@freeze_time("2023-01-01") @pytest.fixture -def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]: - llm_run = LLMRun( - uuid="57a08cc4-73d2-4236-8370-549099d07fad", +def sample_runs() -> Tuple[Run, Run, Run]: + llm_run = Run( + id="57a08cc4-73d2-4236-8370-549099d07fad", name="llm_run", execution_order=1, child_execution_order=1, + parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad", + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), session_id=1, - prompts=[], - response=LLMResult(generations=[[]]), + inputs={"prompts": []}, + outputs=LLMResult(generations=[[]]).dict(), serialized={}, extra={}, + run_type=RunTypeEnum.llm, ) - chain_run = ChainRun( - uuid="57a08cc4-73d2-4236-8371-549099d07fad", + chain_run = Run( + id="57a08cc4-73d2-4236-8371-549099d07fad", name="chain_run", execution_order=1, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), child_execution_order=1, - session_id=1, serialized={}, inputs={}, - outputs=None, - child_llm_runs=[llm_run], - child_chain_runs=[], - child_tool_runs=[], + outputs={}, + child_runs=[llm_run], extra={}, + run_type=RunTypeEnum.chain, ) - tool_run = ToolRun( - uuid="57a08cc4-73d2-4236-8372-549099d07fad", + + tool_run = Run( + id="57a08cc4-73d2-4236-8372-549099d07fad", name="tool_run", execution_order=1, child_execution_order=1, - session_id=1, - tool_input="test", - action="{}", + inputs={"input": "test"}, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + outputs=None, serialized={}, - child_llm_runs=[], - child_chain_runs=[], - child_tool_runs=[], + child_runs=[], extra={}, + run_type=RunTypeEnum.tool, ) return llm_run, chain_run, tool_run -def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None: - expected = {"tenant_id": "test-tenant-id"} - result = lang_chain_tracer_v2._get_default_query_params() - assert result == expected - - -@patch("langchain.callbacks.tracers.langchain.requests.get") -def test_load_session( - mock_requests_get: Mock, - lang_chain_tracer_v2: LangChainTracerV2, - sample_tracer_session_v2: TracerSessionV2, -) -> None: - """Test that load_session method returns a TracerSessionV2 object.""" - mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()] - result = lang_chain_tracer_v2.load_session("test-session-name") - mock_requests_get.assert_called_with( - "http://test-endpoint.com/sessions", - headers={"Content-Type": "application/json", "x-api-key": "foo"}, - params={"tenant_id": "test-tenant-id", "name": "test-session-name"}, - ) - assert result == sample_tracer_session_v2 - - -def test_convert_run( - lang_chain_tracer_v2: LangChainTracerV2, - sample_tracer_session_v2: TracerSessionV2, - sample_runs: Tuple[LLMRun, ChainRun, ToolRun], -) -> None: - llm_run, chain_run, tool_run = sample_runs - lang_chain_tracer_v2.session = sample_tracer_session_v2 - converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run) - converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run) - converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run) - - assert isinstance(converted_llm_run, RunCreate) - assert isinstance(converted_chain_run, RunCreate) - assert isinstance(converted_tool_run, RunCreate) - - def test_persist_run( - lang_chain_tracer_v2: LangChainTracerV2, - sample_tracer_session_v2: TracerSessionV2, - sample_runs: Tuple[LLMRun, ChainRun, ToolRun], + lang_chain_tracer_v2: LangChainTracer, + sample_tracer_session_v2: TracerSession, + sample_runs: Tuple[Run, Run, Run], ) -> None: """Test that persist_run method calls requests.post once per method call.""" with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( @@ -625,25 +94,25 @@ def test_persist_run( ) as get: post.return_value.raise_for_status.return_value = None lang_chain_tracer_v2.session = sample_tracer_session_v2 - llm_run, chain_run, tool_run = sample_runs - lang_chain_tracer_v2._persist_run(llm_run) - lang_chain_tracer_v2._persist_run(chain_run) - lang_chain_tracer_v2._persist_run(tool_run) + for run in sample_runs: + lang_chain_tracer_v2.run_map[str(run.id)] = run + for run in sample_runs: + lang_chain_tracer_v2._end_trace(run) assert post.call_count == 3 assert get.call_count == 0 def test_persist_run_with_example_id( - lang_chain_tracer_v2: LangChainTracerV2, - sample_tracer_session_v2: TracerSessionV2, - sample_runs: Tuple[LLMRun, ChainRun, ToolRun], + lang_chain_tracer_v2: LangChainTracer, + sample_tracer_session_v2: TracerSession, + sample_runs: Tuple[Run, Run, Run], ) -> None: """Test the example ID is assigned only to the parent run and not the children.""" example_id = uuid4() llm_run, chain_run, tool_run = sample_runs - chain_run.child_tool_runs = [tool_run] - tool_run.child_llm_runs = [llm_run] + chain_run.child_runs = [tool_run] + tool_run.child_runs = [llm_run] with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( "langchain.callbacks.tracers.langchain.requests.get" ) as get: @@ -652,55 +121,14 @@ def test_persist_run_with_example_id( lang_chain_tracer_v2.example_id = example_id lang_chain_tracer_v2._persist_run(chain_run) - assert post.call_count == 1 + assert post.call_count == 3 assert get.call_count == 0 - posted_data = json.loads(post.call_args[1]["data"]) - assert posted_data["id"] == chain_run.uuid - assert posted_data["reference_example_id"] == str(example_id) - - def assert_child_run_no_example_id(run: dict) -> None: - assert not run.get("reference_example_id") - for child_run in run.get("child_runs", []): - assert_child_run_no_example_id(child_run) - - for child_run in posted_data["child_runs"]: - assert_child_run_no_example_id(child_run) - - -def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None: - """Test creating the 'SessionCreate' object.""" - lang_chain_tracer_v2.tenant_id = str(_TENANT_ID) - session_create = lang_chain_tracer_v2._get_session_create(name="test") - assert isinstance(session_create, TracerSessionV2Create) - assert session_create.name == "test" - assert session_create.tenant_id == _TENANT_ID - - -@patch("langchain.callbacks.tracers.langchain.requests.post") -def test_persist_session( - mock_requests_post: Mock, - lang_chain_tracer_v2: LangChainTracerV2, - sample_tracer_session_v2: TracerSessionV2, -) -> None: - """Test persist_session returns a TracerSessionV2 with the updated ID.""" - session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict()) - new_id = str(uuid4()) - mock_requests_post.return_value.json.return_value = {"id": new_id} - result = lang_chain_tracer_v2._persist_session(session_create) - assert isinstance(result, TracerSessionV2) - res = sample_tracer_session_v2.dict() - res["id"] = UUID(new_id) - assert result.dict() == res - - -@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session") -def test_load_default_session( - mock_load_session: Mock, - lang_chain_tracer_v2: LangChainTracerV2, - sample_tracer_session_v2: TracerSessionV2, -) -> None: - """Test load_default_session attempts to load with the default name.""" - mock_load_session.return_value = sample_tracer_session_v2 - result = lang_chain_tracer_v2.load_default_session() - assert result == sample_tracer_session_v2 - mock_load_session.assert_called_with("default") + posted_data = [ + json.loads(call_args[1]["data"]) for call_args in post.call_args_list + ] + assert posted_data[0]["id"] == str(chain_run.id) + assert posted_data[0]["reference_example_id"] == str(example_id) + assert posted_data[1]["id"] == str(tool_run.id) + assert not posted_data[1].get("reference_example_id") + assert posted_data[2]["id"] == str(llm_run.id) + assert not posted_data[2].get("reference_example_id") diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py index 731c4d694d..85fcf9436a 100644 --- a/tests/unit_tests/client/test_langchain.py +++ b/tests/unit_tests/client/test_langchain.py @@ -8,8 +8,8 @@ from unittest import mock import pytest from langchain.base_language import BaseLanguageModel -from langchain.callbacks.tracers.langchain import LangChainTracerV2 -from langchain.callbacks.tracers.schemas import TracerSessionV2 +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.schemas import TracerSession from langchain.chains.base import Chain from langchain.client.langchain import ( LangChainPlusClient, @@ -196,10 +196,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: {"result": f"Result for example {example.id}"} for _ in range(n_repetitions) ] - def mock_load_session( - self: Any, name: str, *args: Any, **kwargs: Any - ) -> TracerSessionV2: - return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4()) + def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession: + return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4()) with mock.patch.object( LangChainPlusClient, "read_dataset", new=mock_read_dataset @@ -208,7 +206,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: ), mock.patch.object( LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain ), mock.patch.object( - LangChainTracerV2, "load_session", new=mock_load_session + LangChainTracer, "ensure_session", new=mock_ensure_session ): monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID) client = LangChainPlusClient(