add tracing support to langchain (#741)

* add implementations of `BaseCallbackHandler` to support tracing:
`SharedTracer` which is thread-safe and `Tracer` which is not and is
meant to be used locally.
* Tracers persist runs to locally running `langchain-server`

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
ankush/async-llmchain
Ankush Gola 1 year ago committed by GitHub
parent 7f76a1189c
commit 57609845df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,11 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks import set_default_callback_manager, set_handler
from langchain.callbacks import (
set_default_callback_manager,
set_handler,
set_tracing_callback_manager,
)
from langchain.chains import (
ConversationChain,
LLMBashChain,
@ -68,4 +72,5 @@ __all__ = [
"QAWithSourcesChain",
"PALChain",
"set_handler",
"set_tracing_callback_manager",
]

@ -284,7 +284,7 @@ class AgentExecutor(Chain, BaseModel):
observation = tool.func(output.tool_input)
color = color_mapping[output.tool]
return_direct = tool.return_direct
except Exception as e:
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_tool_error(e, verbose=self.verbose)
raise e
else:

@ -1,11 +1,13 @@
"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator
from typing import Generator, Optional
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
def get_callback_manager() -> BaseCallbackManager:
@ -21,7 +23,31 @@ def set_handler(handler: BaseCallbackHandler) -> None:
def set_default_callback_manager() -> None:
"""Set default callback manager."""
set_handler(StdOutCallbackHandler())
default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
if default_handler == "stdout":
set_handler(StdOutCallbackHandler())
elif default_handler == "langchain":
session = os.environ.get("LANGCHAIN_SESSION")
set_tracing_callback_manager(session)
else:
raise ValueError(
f"LANGCHAIN_HANDLER should be one of `stdout` "
f"or `langchain`, got {default_handler}"
)
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
    """Set tracing callback manager.

    Installs a ``SharedLangChainTracer`` (plus a stdout handler) as the only
    handlers on the shared callback manager, then loads a tracing session.

    Args:
        session_name: Name of the tracing session to load. When None, the
            tracer's default session is loaded instead.

    Raises:
        ValueError: If ``session_name`` is given but the session cannot be
            loaded.
    """
    handler = SharedLangChainTracer()
    callback = get_callback_manager()
    callback.set_handlers([handler, StdOutCallbackHandler()])
    if session_name is None:
        handler.load_default_session()
    else:
        try:
            handler.load_session(session_name)
        except Exception as e:
            # Chain the underlying failure so the root cause is not lost.
            raise ValueError(f"session {session_name} not found") from e
@contextmanager

@ -1,25 +1,34 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
from typing import Any, Dict, List, Union
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(BaseModel, ABC):
class BaseCallbackHandler(ABC):
"""Base callback handler that can be used to handle callbacks from langchain."""
ignore_llm: bool = False
ignore_chain: bool = False
ignore_agent: bool = False
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return False
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return False
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return False
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@ -31,7 +40,9 @@ class BaseCallbackHandler(BaseModel, ABC):
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
@abstractmethod
@ -45,7 +56,9 @@ class BaseCallbackHandler(BaseModel, ABC):
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
@abstractmethod
@ -59,7 +72,9 @@ class BaseCallbackHandler(BaseModel, ABC):
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
@abstractmethod
@ -82,15 +97,21 @@ class BaseCallbackManager(BaseCallbackHandler, ABC):
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
@abstractmethod
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler])
@abstractmethod
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
handlers: List[BaseCallbackHandler]
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
def on_llm_start(
self,
@ -115,7 +136,10 @@ class CallbackManager(BaseCallbackManager):
handler.on_llm_end(response)
def on_llm_error(
self, error: Exception, verbose: bool = False, **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
@ -146,7 +170,10 @@ class CallbackManager(BaseCallbackManager):
handler.on_chain_end(outputs)
def on_chain_error(
self, error: Exception, verbose: bool = False, **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain errors."""
for handler in self.handlers:
@ -175,7 +202,10 @@ class CallbackManager(BaseCallbackManager):
handler.on_tool_end(output, **kwargs)
def on_tool_error(
self, error: Exception, verbose: bool = False, **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool errors."""
for handler in self.handlers:
@ -206,6 +236,6 @@ class CallbackManager(BaseCallbackManager):
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
self.handlers = [handler]
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers

@ -1,5 +1,5 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -29,7 +29,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):
if "total_tokens" in token_usage:
self.total_tokens += token_usage["total_tokens"]
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
@ -43,7 +45,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):
"""Print out that we finished a chain."""
pass
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
@ -68,7 +72,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):
"""If not the final action, print out observation."""
pass
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass

@ -1,7 +1,7 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from langchain.callbacks.base import (
BaseCallbackHandler,
@ -46,7 +46,9 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock:
self._callback_manager.on_llm_end(response, **kwargs)
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
with self._lock:
self._callback_manager.on_llm_error(error, **kwargs)
@ -63,7 +65,9 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock:
self._callback_manager.on_chain_end(outputs, **kwargs)
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
with self._lock:
self._callback_manager.on_chain_error(error, **kwargs)
@ -80,7 +84,9 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock:
self._callback_manager.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
with self._lock:
self._callback_manager.on_tool_error(error, **kwargs)
@ -105,7 +111,7 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock:
self._callback_manager.remove_handler(callback)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
with self._lock:
self._callback_manager.handlers = [handler]
self._callback_manager.handlers = handlers

@ -1,5 +1,5 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@ -19,7 +19,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
pass
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
@ -34,7 +36,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
@ -61,7 +65,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
print_text(output, color=color)
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass

@ -1,5 +1,5 @@
"""Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import streamlit as st
@ -22,7 +22,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
pass
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
@ -37,7 +39,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
"""Print out that we finished a chain."""
st.write("Finished chain.")
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
@ -62,7 +66,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
st.write(f"{observation_prefix}{output}")
st.write(llm_prefix)
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass

@ -0,0 +1,12 @@
"""Tracers that record execution of LangChain runs."""
from langchain.callbacks.tracers.base import SharedTracer, Tracer
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
    """Shared tracer that records LangChain execution to LangChain endpoint."""

    # Combines the thread-safe Singleton storage of SharedTracer with the
    # HTTP persistence of BaseLangChainTracer; no extra behavior of its own.
class LangChainTracer(Tracer, BaseLangChainTracer):
    """Tracer that records LangChain execution to LangChain endpoint."""

    # Non-thread-safe variant: per-instance state from Tracer, HTTP
    # persistence from BaseLangChainTracer.

@ -0,0 +1,334 @@
"""Base interfaces for tracing runs."""
from __future__ import annotations
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.shared import Singleton
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
ToolRun,
TracerSession,
TracerSessionCreate,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
class TracerException(Exception):
    """Base class for exceptions in tracers module."""
class BaseTracer(BaseCallbackHandler, ABC):
    """Base interface for tracers.

    Maintains a stack of in-flight runs: each nested LLM/chain/tool run is
    attached as a child of the enclosing ChainRun/ToolRun, and a run tree is
    persisted when its outermost run finishes. Storage of the stack,
    execution order, and session is deferred to subclasses via the abstract
    properties below.
    """

    @abstractmethod
    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""

    @abstractmethod
    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run."""

    @abstractmethod
    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session."""

    @abstractmethod
    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run."""

    def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
        """NOT thread safe, do not call this method from multiple threads."""
        session_create = TracerSessionCreate(name=name, extra=kwargs)
        session = self._persist_session(session_create)
        self._session = session
        return session

    @abstractmethod
    def load_session(self, session_name: str) -> TracerSession:
        """Load a tracing session and set it as the Tracer's session."""

    @abstractmethod
    def load_default_session(self) -> TracerSession:
        """Load the default tracing session and set it as the Tracer's session."""

    @property
    @abstractmethod
    def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
        """Get the tracer stack."""

    @property
    @abstractmethod
    def _execution_order(self) -> int:
        """Get the execution order for a run."""

    @_execution_order.setter
    @abstractmethod
    def _execution_order(self, value: int) -> None:
        """Set the execution order for a run."""

    @property
    @abstractmethod
    def _session(self) -> Optional[TracerSession]:
        """Get the tracing session."""

    @_session.setter
    @abstractmethod
    def _session(self, value: TracerSession) -> None:
        """Set the tracing session."""

    def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Start a trace for a run: attach it to its parent and push it."""
        self._execution_order += 1

        if self._stack:
            # Only container runs (chains/tools) may own children; an LLMRun
            # on top of the stack cannot have anything nested beneath it.
            if not (
                isinstance(self._stack[-1], ChainRun)
                or isinstance(self._stack[-1], ToolRun)
            ):
                raise TracerException(
                    f"Nested {run.__class__.__name__} can only be"
                    f" logged inside a ChainRun or ToolRun"
                )
            self._add_child_run(self._stack[-1], run)
        self._stack.append(run)

    def _end_trace(self) -> None:
        """End a trace for a run."""
        run = self._stack.pop()
        if not self._stack:
            # Outermost run finished: reset the ordering counter and persist
            # the whole tree (children were attached in _start_trace).
            self._execution_order = 1
            self._persist_run(run)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Start a trace for an LLM run."""
        if self._session is None:
            raise TracerException(
                "Initialize a session with `new_session()` before starting a trace."
            )

        llm_run = LLMRun(
            serialized=serialized,
            prompts=prompts,
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=self._execution_order,
            session_id=self._session.id,
            id=self._generate_id(),
        )
        self._start_trace(llm_run)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """End a trace for an LLM run."""
        if not self._stack or not isinstance(self._stack[-1], LLMRun):
            raise TracerException("No LLMRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].response = response
        self._end_trace()

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Handle an error for an LLM run."""
        if not self._stack or not isinstance(self._stack[-1], LLMRun):
            raise TracerException("No LLMRun found to be traced")

        # Store repr(error) so the run stays JSON-serializable.
        self._stack[-1].error = repr(error)
        self._stack[-1].end_time = datetime.utcnow()
        self._end_trace()

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Start a trace for a chain run."""
        if self._session is None:
            raise TracerException(
                "Initialize a session with `new_session()` before starting a trace."
            )

        chain_run = ChainRun(
            serialized=serialized,
            inputs=inputs,
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=self._execution_order,
            child_runs=[],
            session_id=self._session.id,
            id=self._generate_id(),
        )
        self._start_trace(chain_run)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """End a trace for a chain run."""
        if not self._stack or not isinstance(self._stack[-1], ChainRun):
            raise TracerException("No ChainRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].outputs = outputs
        self._end_trace()

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Handle an error for a chain run."""
        if not self._stack or not isinstance(self._stack[-1], ChainRun):
            raise TracerException("No ChainRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].error = repr(error)
        self._end_trace()

    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Start a trace for a tool run."""
        if self._session is None:
            raise TracerException(
                "Initialize a session with `new_session()` before starting a trace."
            )

        tool_run = ToolRun(
            serialized=serialized,
            # The tool name is recorded as the run's action.
            action=action.tool,
            tool_input=action.tool_input,
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=self._execution_order,
            child_runs=[],
            session_id=self._session.id,
            id=self._generate_id(),
        )
        self._start_trace(tool_run)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """End a trace for a tool run."""
        if not self._stack or not isinstance(self._stack[-1], ToolRun):
            raise TracerException("No ToolRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].output = output
        self._end_trace()

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Handle an error for a tool run."""
        if not self._stack or not isinstance(self._stack[-1], ToolRun):
            raise TracerException("No ToolRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].error = repr(error)
        self._end_trace()

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Handle a text message. Text events are not traced."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Handle an agent finish message. The enclosing chain records it."""
        pass
class Tracer(BaseTracer, ABC):
    """A non-thread safe implementation of the BaseTracer interface.

    State (run stack, execution order, session) lives on the instance, so a
    Tracer must not be shared across threads.
    """

    def __init__(self) -> None:
        """Initialize a tracer with an empty stack and no session."""
        self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
        # Execution order starts at 1 and is reset when a trace completes.
        self._tracer_execution_order = 1
        self._tracer_session: Optional[TracerSession] = None

    @property
    def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
        """Get the tracer stack."""
        return self._tracer_stack

    @property
    def _execution_order(self) -> int:
        """Get the execution order for a run."""
        return self._tracer_execution_order

    @_execution_order.setter
    def _execution_order(self, value: int) -> None:
        """Set the execution order for a run."""
        self._tracer_execution_order = value

    @property
    def _session(self) -> Optional[TracerSession]:
        """Get the tracing session."""
        return self._tracer_session

    @_session.setter
    def _session(self, value: TracerSession) -> None:
        """Set the tracing session."""
        if self._stack:
            # Swapping sessions mid-trace would mislabel in-flight runs.
            raise TracerException(
                "Cannot set a session while a trace is being recorded"
            )
        self._tracer_session = value
@dataclass
class TracerStack(threading.local):
    """A stack of runs used for logging.

    Subclassing threading.local gives every thread its own independent
    stack and execution order.
    """

    stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
    execution_order: int = 1
class SharedTracer(Singleton, BaseTracer, ABC):
    """A thread-safe Singleton implementation of BaseTracer.

    The run stack and execution order are thread-local (see TracerStack);
    the tracing session is shared by all threads.
    """

    _tracer_stack = TracerStack()
    _tracer_session = None

    @property
    def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
        """Get the tracer stack for the current thread."""
        return self._tracer_stack.stack

    @property
    def _execution_order(self) -> int:
        """Get the execution order for a run (current thread)."""
        return self._tracer_stack.execution_order

    @_execution_order.setter
    def _execution_order(self, value: int) -> None:
        """Set the execution order for a run (current thread)."""
        self._tracer_stack.execution_order = value

    @property
    def _session(self) -> Optional[TracerSession]:
        """Get the tracing session shared by all threads."""
        return self._tracer_session

    @_session.setter
    def _session(self, value: TracerSession) -> None:
        """Set the tracing session."""
        with self._lock:
            # TODO: currently, we are only checking current thread's stack.
            # Need to make sure that we are not in the middle of a trace
            # in any thread.
            if self._stack:
                raise TracerException(
                    "Cannot set a session while a trace is being recorded"
                )
            self._tracer_session = value

@ -0,0 +1,112 @@
"""A Tracer implementation that records to LangChain endpoint."""
from __future__ import annotations
import logging
import os
from abc import ABC
from typing import Any, Dict, Optional, Union
import requests
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
ToolRun,
TracerSession,
TracerSessionCreate,
)
class BaseLangChainTracer(BaseTracer, ABC):
    """A tracer that POSTs runs and sessions to the LangChain endpoint.

    The endpoint defaults to a locally running ``langchain-server``
    (http://localhost:8000) and can be overridden via the
    ``LANGCHAIN_ENDPOINT`` environment variable; an optional API key is read
    from ``LANGCHAIN_API_KEY``. All network failures are logged and degrade
    gracefully — tracing must never break the traced run.
    """

    # Always emit callbacks so runs are recorded even when verbose=False.
    always_verbose: bool = True
    _endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
    _headers: Dict[str, Any] = {"Content-Type": "application/json"}
    if os.getenv("LANGCHAIN_API_KEY"):
        _headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """POST a finished run to the endpoint matching its type."""
        if isinstance(run, LLMRun):
            endpoint = f"{self._endpoint}/llm-runs"
        elif isinstance(run, ChainRun):
            endpoint = f"{self._endpoint}/chain-runs"
        else:
            endpoint = f"{self._endpoint}/tool-runs"
        try:
            requests.post(
                endpoint,
                data=run.json(),
                headers=self._headers,
            )
        except Exception as e:
            # Best-effort persistence: log and move on.
            logging.warning(f"Failed to persist run: {e}")

    def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
        """Create a session server-side; fall back to a default id=1 session."""
        try:
            r = requests.post(
                f"{self._endpoint}/sessions",
                data=session_create.json(),
                headers=self._headers,
            )
            session = TracerSession(id=r.json()["id"], **session_create.dict())
        except Exception as e:
            logging.warning(f"Failed to create session, using default session: {e}")
            session = TracerSession(id=1, **session_create.dict())
        return session

    def _fetch_session(self, url: str, failure_msg: str) -> TracerSession:
        """GET a session list from `url` and adopt the first result.

        On any failure, logs `failure_msg` and falls back to an empty
        id=1 session. Shared by load_session/load_default_session.
        """
        try:
            r = requests.get(url, headers=self._headers)
            # The sessions endpoint returns a list; use the first result.
            tracer_session = TracerSession(**r.json()[0])
        except Exception as e:
            logging.warning(f"{failure_msg}: {e}")
            tracer_session = TracerSession(id=1)
        self._session = tracer_session
        return tracer_session

    def load_session(self, session_name: str) -> TracerSession:
        """Load a session by name and set it as the Tracer's session."""
        return self._fetch_session(
            f"{self._endpoint}/sessions?name={session_name}",
            f"Failed to load session {session_name}, using empty session",
        )

    def load_default_session(self) -> TracerSession:
        """Load the default tracing session and set it as the Tracer's session."""
        return self._fetch_session(
            f"{self._endpoint}/sessions",
            "Failed to load default session, using empty session",
        )

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Append the child to the parent's type-specific child list."""
        if isinstance(child_run, LLMRun):
            parent_run.child_llm_runs.append(child_run)
        elif isinstance(child_run, ChainRun):
            parent_run.child_chain_runs.append(child_run)
        else:
            parent_run.child_tool_runs.append(child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run. None: the server assigns ids."""
        return None

@ -0,0 +1,76 @@
"""Schemas for tracers."""
from __future__ import annotations
import datetime
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from langchain.schema import LLMResult
class TracerSessionBase(BaseModel):
    """Base class for TracerSession."""

    # Session creation time; defaults to "now" (UTC, naive).
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Optional human-readable session name.
    name: Optional[str] = None
    # Free-form metadata attached to the session.
    extra: Optional[Dict[str, Any]] = None
class TracerSessionCreate(TracerSessionBase):
    """Create class for TracerSession: the payload sent before an id exists."""

    pass
class TracerSession(TracerSessionBase):
    """TracerSession schema: a persisted session with its server-assigned id."""

    id: int
class BaseRun(BaseModel):
    """Base class for Run."""

    # None until assigned; may be an int or string depending on the tracer.
    id: Optional[Union[int, str]] = None
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Extra kwargs captured at the callback call site.
    extra: Optional[Dict[str, Any]] = None
    # 1-based position of this run within its trace.
    execution_order: int
    # Serialized form of the component (LLM/chain/tool) that produced the run.
    serialized: Dict[str, Any]
    session_id: int
    # repr() of the exception if the run errored, else None.
    error: Optional[str] = None
class LLMRun(BaseRun):
    """Class for LLMRun."""

    prompts: List[str]
    # None until the LLM call completes.
    response: Optional[LLMResult] = None
class ChainRun(BaseRun):
    """Class for ChainRun."""

    inputs: Dict[str, Any]
    # None until the chain completes.
    outputs: Optional[Dict[str, Any]] = None
    # Children bucketed by type.
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)
    # Flat list of all children regardless of type.
    child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
class ToolRun(BaseRun):
    """Class for ToolRun."""

    tool_input: str
    # None until the tool completes.
    output: Optional[str] = None
    # Name of the tool that was invoked.
    action: str
    # Children bucketed by type.
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)
    # Flat list of all children regardless of type.
    child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)


# Resolve the self-referential forward annotations (ChainRun/ToolRun refer
# to themselves and to each other) now that both classes are defined.
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()

@ -150,7 +150,7 @@ class Chain(BaseModel, ABC):
)
try:
outputs = self._call(inputs)
except Exception as e:
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)

@ -0,0 +1,29 @@
version: '3'
services:
  # Web UI for browsing traced runs; talks to the backend over the
  # compose network, exposed locally on port 4173.
  langchain-frontend:
    image: notlangchain/langchainplus-frontend:latest
    ports:
      - 4173:4173
    environment:
      - BACKEND_URL=http://langchain-backend:8000
      - PUBLIC_BASE_URL=http://localhost:8000
      - PUBLIC_DEV_MODE=true
    depends_on:
      - langchain-backend
  # API server that tracers POST runs/sessions to (port 8000).
  langchain-backend:
    image: notlangchain/langchainplus:latest
    environment:
      - PORT=8000
      - LANGCHAIN_ENV=local
    ports:
      - 8000:8000
    depends_on:
      - langchain-db
  # Postgres database backing the API server. Credentials are the
  # postgres defaults — for local development only.
  langchain-db:
    image: postgres:14.1
    environment:
      - POSTGRES_PASSWORD=postgres
      - POSTGRES_USER=postgres
      - POSTGRES_DB=postgres
    ports:
      - 5432:5432

@ -74,7 +74,7 @@ class BaseLLM(BaseModel, ABC):
)
try:
output = self._generate(prompts, stop=stop)
except Exception as e:
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
@ -97,7 +97,7 @@ class BaseLLM(BaseModel, ABC):
)
try:
new_results = self._generate(missing_prompts, stop=stop)
except Exception as e:
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)

@ -1,7 +1,10 @@
"""Common schema objects."""
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional
from dataclasses_json import dataclass_json
class AgentAction(NamedTuple):
"""Agent's action to take."""
@ -18,7 +21,9 @@ class AgentFinish(NamedTuple):
log: str
class Generation(NamedTuple):
@dataclass_json
@dataclass
class Generation:
"""Output of a single generation."""
text: str
@ -30,7 +35,9 @@ class Generation(NamedTuple):
# TODO: add log probs
class LLMResult(NamedTuple):
@dataclass_json
@dataclass
class LLMResult:
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]

@ -0,0 +1,14 @@
"""Script to run langchain-server locally using docker-compose."""
import subprocess
from pathlib import Path


def main() -> None:
    """Run the langchain server locally: pull images, then bring the stack up."""
    compose_file = Path(__file__).absolute().parent / "docker-compose.yaml"
    for subcommand in ("pull", "up"):
        subprocess.run(["docker-compose", "-f", str(compose_file), subcommand])


if __name__ == "__main__":
    main()

97
poetry.lock generated

@ -784,6 +784,26 @@ files = [
{file = "cymem-2.0.7.tar.gz", hash = "sha256:e6034badb5dd4e10344211c81f16505a55553a7164adc314c75bd80cf07e57a8"},
]
[[package]]
name = "dataclasses-json"
version = "0.5.7"
description = "Easily serialize dataclasses to and from JSON"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "dataclasses-json-0.5.7.tar.gz", hash = "sha256:c2c11bc8214fbf709ffc369d11446ff6945254a7f09128154a7620613d8fda90"},
{file = "dataclasses_json-0.5.7-py3-none-any.whl", hash = "sha256:bc285b5f892094c3a53d558858a88553dd6a61a11ab1a8128a0e554385dcc5dd"},
]
[package.dependencies]
marshmallow = ">=3.3.0,<4.0.0"
marshmallow-enum = ">=1.5.1,<2.0.0"
typing-inspect = ">=0.4.0"
[package.extras]
dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest (>=6.2.3)", "simplejson", "types-dataclasses"]
[[package]]
name = "debugpy"
version = "1.6.5"
@ -1152,6 +1172,21 @@ files = [
{file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"},
]
[[package]]
name = "freezegun"
version = "1.2.2"
description = "Let your Python tests travel through time"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "freezegun-1.2.2-py3-none-any.whl", hash = "sha256:ea1b963b993cb9ea195adbd893a48d573fda951b0da64f60883d7e988b606c9f"},
{file = "freezegun-1.2.2.tar.gz", hash = "sha256:cd22d1ba06941384410cd967d8a99d5ae2442f57dfafeff2fda5de8dc5c05446"},
]
[package.dependencies]
python-dateutil = ">=2.7"
[[package]]
name = "google-api-core"
version = "2.11.0"
@ -2235,7 +2270,6 @@ files = [
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"},
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"},
{file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"},
{file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"},
{file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"},
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"},
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"},
@ -2245,7 +2279,6 @@ files = [
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"},
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"},
{file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"},
{file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"},
{file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"},
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"},
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"},
@ -2404,6 +2437,42 @@ files = [
{file = "MarkupSafe-2.1.2.tar.gz", hash = "sha256:abcabc8c2b26036d62d4c746381a6f7cf60aafcc653198ad678306986b09450d"},
]
[[package]]
name = "marshmallow"
version = "3.19.0"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "marshmallow-3.19.0-py3-none-any.whl", hash = "sha256:93f0958568da045b0021ec6aeb7ac37c81bfcccbb9a0e7ed8559885070b3a19b"},
{file = "marshmallow-3.19.0.tar.gz", hash = "sha256:90032c0fd650ce94b6ec6dc8dfeb0e3ff50c144586462c389b81a07205bedb78"},
]
[package.dependencies]
packaging = ">=17.0"
[package.extras]
dev = ["flake8 (==5.0.4)", "flake8-bugbear (==22.10.25)", "mypy (==0.990)", "pre-commit (>=2.4,<3.0)", "pytest", "pytz", "simplejson", "tox"]
docs = ["alabaster (==0.7.12)", "autodocsumm (==0.2.9)", "sphinx (==5.3.0)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"]
lint = ["flake8 (==5.0.4)", "flake8-bugbear (==22.10.25)", "mypy (==0.990)", "pre-commit (>=2.4,<3.0)"]
tests = ["pytest", "pytz", "simplejson"]
[[package]]
name = "marshmallow-enum"
version = "1.5.1"
description = "Enum field for Marshmallow"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "marshmallow-enum-1.5.1.tar.gz", hash = "sha256:38e697e11f45a8e64b4a1e664000897c659b60aa57bfa18d44e226a9920b6e58"},
{file = "marshmallow_enum-1.5.1-py2.py3-none-any.whl", hash = "sha256:57161ab3dbfde4f57adeb12090f39592e992b9c86d206d02f6bd03ebec60f072"},
]
[package.dependencies]
marshmallow = ">=2.0.0"
[[package]]
name = "matplotlib-inline"
version = "0.1.6"
@ -2580,7 +2649,7 @@ reports = ["lxml"]
name = "mypy-extensions"
version = "0.4.3"
description = "Experimental type system extensions for programs checked with the mypy typechecker."
category = "dev"
category = "main"
optional = false
python-versions = "*"
files = [
@ -4941,14 +5010,11 @@ files = [
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
{file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"},
{file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
{file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"},
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
@ -4957,7 +5023,6 @@ files = [
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
@ -5284,6 +5349,22 @@ files = [
{file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"},
]
[[package]]
name = "typing-inspect"
version = "0.8.0"
description = "Runtime inspection utilities for typing module."
category = "main"
optional = false
python-versions = "*"
files = [
{file = "typing_inspect-0.8.0-py3-none-any.whl", hash = "sha256:5fbf9c1e65d4fa01e701fe12a5bca6c6e08a4ffd5bc60bfac028253a447c5188"},
{file = "typing_inspect-0.8.0.tar.gz", hash = "sha256:8b1ff0c400943b6145df8119c41c244ca8207f1f10c9c057aeed1560e4806e3d"},
]
[package.dependencies]
mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4"
[[package]]
name = "uri-template"
version = "1.2.0"
@ -5585,4 +5666,4 @@ llms = ["manifest-ml", "torch", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "2da2c13a9a1572e4f7792c336c57f7087d79c21a1370d6f0534bcfa80abcad30"
content-hash = "537ede877b299a8800eb26a428607be97b59f89435259abda2c7cc86092306c6"

@ -7,6 +7,9 @@ license = "MIT"
readme = "README.md"
repository = "https://www.github.com/hwchase17/langchain"
[tool.poetry.scripts]
langchain-server = "langchain.server:main"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
pydantic = "^1"
@ -31,6 +34,7 @@ weaviate-client = {version = "^3", optional = true}
google-api-python-client = {version = "2.70.0", optional = true}
wolframalpha = {version = "5.0.0", optional = true}
qdrant-client = {version = "^0.11.7", optional = true}
dataclasses-json = "^0.5.7"
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@ -52,6 +56,7 @@ pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.6.6"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
[tool.poetry.group.lint.dependencies]
flake8-docstrings = "^1.6.0"
@ -59,7 +64,7 @@ black = "^22.10.0"
isort = "^5.10.1"
flake8 = "^6.0.0"
types-toml = "^0.10.8.1"
types-redis = "^4.3.21.6"
types-redis = "^4.3.21.6"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"

@ -1,17 +1,43 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from pydantic import BaseModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Fake callback handler for testing."""
starts: int = 0
ends: int = 0
errors: int = 0
text: int = 0
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
always_verbose_: bool = False
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@ -23,7 +49,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.errors += 1
@ -37,7 +65,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
self.errors += 1
@ -51,7 +81,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
self.errors += 1

@ -8,6 +8,31 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def _test_callback_manager(
    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
    """Fire one of every event WITHOUT verbose, then check handler counters.

    Because no ``verbose=True`` is passed, only handlers whose
    ``always_verbose`` property is True should have been invoked; the rest
    must have seen nothing.
    """
    # 3 "start" events: llm, chain, tool.
    manager.on_llm_start({}, [])
    manager.on_llm_end(LLMResult(generations=[]))
    manager.on_llm_error(Exception())
    manager.on_chain_start({"name": "foo"}, {})
    manager.on_chain_end({})
    manager.on_chain_error(Exception())
    manager.on_tool_start({}, AgentAction("", "", ""))
    manager.on_tool_end("")
    manager.on_tool_error(Exception())
    # agent_finish counts as a 4th "end" event.
    manager.on_agent_finish(AgentFinish(log="", return_values={}))
    for handler in handlers:
        if handler.always_verbose:
            # Saw everything: 3 starts, 3 ends + agent_finish, 3 errors.
            assert handler.starts == 3
            assert handler.ends == 4
            assert handler.errors == 3
        else:
            # Non-verbose handlers are skipped entirely.
            assert handler.starts == 0
            assert handler.ends == 0
            assert handler.errors == 0
def _test_callback_manager_pass_in_verbose(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [], verbose=True)
@ -19,7 +44,7 @@ def _test_callback_manager(
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
for handler in handlers:
assert handler.starts == 3
assert handler.ends == 4
@ -27,17 +52,25 @@ def _test_callback_manager(
def test_callback_manager() -> None:
    """Test the CallbackManager with one verbose and one quiet handler.

    The helper asserts that only the ``always_verbose_=True`` handler
    receives events when no per-call ``verbose`` flag is given.
    """
    handler1 = FakeCallbackHandler(always_verbose_=True)
    handler2 = FakeCallbackHandler(always_verbose_=False)
    manager = CallbackManager([handler1, handler2])
    _test_callback_manager(manager, handler1, handler2)
def test_callback_manager_pass_in_verbose() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
manager = CallbackManager([handler1, handler2])
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
def test_ignore_llm() -> None:
"""Test ignore llm param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_llm=True)
handler2 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_llm_start({}, [], verbose=True)
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
@ -52,8 +85,8 @@ def test_ignore_llm() -> None:
def test_ignore_chain() -> None:
"""Test ignore chain param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_chain=True)
handler2 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
@ -68,8 +101,8 @@ def test_ignore_chain() -> None:
def test_ignore_agent() -> None:
"""Test ignore agent param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_agent=True)
handler2 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_end("", verbose=True)
@ -90,7 +123,7 @@ def test_shared_callback_manager() -> None:
assert manager1 is manager2
handler1 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)

@ -0,0 +1 @@
"""Tests for correct functioning of tracers."""

@ -0,0 +1,530 @@
"""Test Tracer classes."""
from __future__ import annotations
import threading
from datetime import datetime
from typing import List, Optional, Union
import pytest
from freezegun import freeze_time
from langchain.callbacks.tracers.base import (
BaseTracer,
ChainRun,
LLMRun,
SharedTracer,
ToolRun,
Tracer,
TracerException,
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.schema import AgentAction, LLMResult
TEST_SESSION_ID = 2023
@freeze_time("2023-01-01")
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
return ChainRun(
id=None,
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
child_runs=[
ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
serialized={},
tool_input="test",
output="test",
action="action",
session_id=TEST_SESSION_ID,
error=None,
child_runs=[
LLMRun(
id=None,
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
)
],
),
LLMRun(
id=None,
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
),
],
)
def _perform_nested_run(tracer: BaseTracer) -> None:
    """Perform a nested run: chain -> [tool -> [llm], llm].

    Drives the tracer's callback interface directly; the resulting run tree
    must equal ``_get_compare_run()``.
    """
    tracer.on_chain_start(serialized={}, inputs={})
    # Tool run nested inside the chain...
    tracer.on_tool_start(
        serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
    )
    # ...with an LLM run nested inside the tool.
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult([[]]))
    tracer.on_tool_end("test")
    # Second LLM run, a direct child of the chain.
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult([[]]))
    tracer.on_chain_end(outputs={})
def _add_child_run(
    parent_run: Union[ChainRun, ToolRun],
    child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
    """Add child run to a chain run or tool run.

    Shared by both fake tracers so they nest runs identically.
    """
    parent_run.child_runs.append(child_run)
def _generate_id() -> Optional[Union[int, str]]:
    """Generate an id for a run.

    The fakes deliberately return ``None`` so expected runs can be written
    with ``id=None`` and compared by value.
    """
    return None
def load_session(session_name: str) -> TracerSession:
    """Load a tracing session by name (fake: fixed id=1, frozen time)."""
    return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
def _persist_session(session: TracerSessionCreate) -> TracerSession:
    """Persist a tracing session (fake: just stamps TEST_SESSION_ID on it)."""
    return TracerSession(id=TEST_SESSION_ID, **session.dict())
def load_default_session() -> TracerSession:
    """Load the default tracing session (fake: fixed id=1, name 'default')."""
    return TracerSession(id=1, name="default", start_time=datetime.utcnow())
class FakeTracer(Tracer):
    """Fake tracer that records LangChain execution in memory.

    Persisted runs accumulate on ``self.runs`` instead of being sent to a
    server; the id/session/child-run hooks delegate to the module-level
    helpers so this class and ``FakeSharedTracer`` behave identically.
    """

    def __init__(self) -> None:
        """Initialize the tracer with an empty run log."""
        super().__init__()
        # Every completed top-level run lands here via _persist_run.
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run by appending it to the in-memory log."""
        self.runs.append(run)

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run (always ``None`` in tests)."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session (fake: assigns TEST_SESSION_ID)."""
        return _persist_session(session)

    def load_session(self, session_name: str) -> TracerSession:
        """Load a tracing session by name."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Load the default tracing session."""
        return load_default_session()
class FakeSharedTracer(SharedTracer):
    """Fake shared tracer that records LangChain execution.

    Unlike ``FakeTracer`` this guards its run log with the tracer's lock,
    since ``SharedTracer`` is used from multiple threads in the tests.
    """

    # Class-level list: state is shared across tests, which is why the
    # multithreaded tests call remove_runs() before asserting.
    runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run under the shared lock."""
        with self._lock:
            self.runs.append(run)

    def remove_runs(self) -> None:
        """Remove all runs.

        NOTE(review): ``self.runs = []`` rebinds an instance attribute that
        shadows the class-level list — fine if SharedTracer guarantees a
        single instance; confirm against its implementation.
        """
        with self._lock:
            self.runs = []

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run (always ``None`` in tests)."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session (fake: assigns TEST_SESSION_ID)."""
        return _persist_session(session)

    def load_session(self, session_name: str) -> TracerSession:
        """Load a tracing session by name."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Load the default tracing session."""
        return load_default_session()
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
compare_run = LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_session() -> None:
"""Test tracer on an LLM run without a session."""
tracer = FakeTracer()
with pytest.raises(TracerException):
tracer.on_llm_start(serialized={}, prompts=[])
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""
tracer = FakeTracer()
tracer.new_session()
with pytest.raises(TracerException):
tracer.on_llm_end(response=LLMResult([[]]))
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
"""Test the tracer with multiple runs."""
compare_run = LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
assert tracer.runs == [compare_run] * num_runs
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
compare_run = ChainRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_chain_end(outputs={})
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
compare_run = ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
tool_input="test",
output="test",
action="action",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_end("test")
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
"""Test tracer on a nested run."""
tracer = FakeTracer()
tracer.new_session()
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()]
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
"""Test tracer on an LLM run with an error."""
exception = Exception("test")
compare_run = LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
prompts=[],
response=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
"""Test tracer on a Chain run with an error."""
exception = Exception("test")
compare_run = ChainRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
inputs={},
outputs=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_chain_error(exception)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
"""Test tracer on a Tool run with an error."""
exception = Exception("test")
compare_run = ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
tool_input="test",
output=None,
action="action",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_error(exception)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
"""Test tracer on a nested run with an error."""
exception = Exception("test")
tracer = FakeTracer()
tracer.new_session()
for _ in range(3):
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
tracer.on_chain_error(exception)
compare_run = ChainRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
inputs={},
outputs=None,
child_runs=[
LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
serialized={},
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
serialized={},
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
tool_input="test",
output=None,
action="action",
child_runs=[
LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=5,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],
response=None,
)
],
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
),
],
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
)
assert tracer.runs == [compare_run] * 3
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run() -> None:
"""Test shared tracer on a nested run."""
tracer = FakeSharedTracer()
tracer.new_session()
tracer.remove_runs()
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()]
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run_multithreaded() -> None:
"""Test shared tracer on a nested run."""
tracer = FakeSharedTracer()
tracer.remove_runs()
tracer.new_session()
threads = []
num_threads = 10
for _ in range(num_threads):
thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
assert tracer.runs == [_get_compare_run()] * num_threads
Loading…
Cancel
Save