diff --git a/langchain/__init__.py b/langchain/__init__.py index c0bddaa9..9b9bfb35 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -4,7 +4,11 @@ from typing import Optional from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.cache import BaseCache -from langchain.callbacks import set_default_callback_manager, set_handler +from langchain.callbacks import ( + set_default_callback_manager, + set_handler, + set_tracing_callback_manager, +) from langchain.chains import ( ConversationChain, LLMBashChain, @@ -68,4 +72,5 @@ __all__ = [ "QAWithSourcesChain", "PALChain", "set_handler", + "set_tracing_callback_manager", ] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 99fb2c1e..29010663 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -284,7 +284,7 @@ class AgentExecutor(Chain, BaseModel): observation = tool.func(output.tool_input) color = color_mapping[output.tool] return_direct = tool.return_direct - except Exception as e: + except (KeyboardInterrupt, Exception) as e: self.callback_manager.on_tool_error(e, verbose=self.verbose) raise e else: diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index 501de6d3..2bc79d6a 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -1,11 +1,13 @@ """Callback handlers that allow listening to events in LangChain.""" +import os from contextlib import contextmanager -from typing import Generator +from typing import Generator, Optional from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.callbacks.tracers import SharedLangChainTracer def get_callback_manager() -> BaseCallbackManager: @@ -21,7 +23,31 @@ def set_handler(handler: BaseCallbackHandler) -> None: def set_default_callback_manager() -> None: """Set default callback manager.""" - set_handler(StdOutCallbackHandler()) + default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout") + if default_handler == "stdout": + set_handler(StdOutCallbackHandler()) + elif default_handler == "langchain": + session = os.environ.get("LANGCHAIN_SESSION") + set_tracing_callback_manager(session) + else: + raise ValueError( + f"LANGCHAIN_HANDLER should be one of `stdout` " + f"or `langchain`, got {default_handler}" + ) + + +def set_tracing_callback_manager(session_name: Optional[str] = None) -> None: + """Set tracing callback manager.""" + handler = SharedLangChainTracer() + callback = get_callback_manager() + callback.set_handlers([handler, StdOutCallbackHandler()]) + if session_name is None: + handler.load_default_session() + else: + try: + handler.load_session(session_name) + except Exception: + raise ValueError(f"session {session_name} not found") @contextmanager diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index d04bc1a6..aeb409f9 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -1,25 +1,34 @@ """Base callback handler that can be used to handle callbacks from langchain.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List - -from pydantic import BaseModel +from typing import Any, Dict, List, Union from langchain.schema import AgentAction, AgentFinish, LLMResult -class BaseCallbackHandler(BaseModel, ABC): +class BaseCallbackHandler(ABC): """Base callback handler that can be used to handle callbacks from langchain.""" - ignore_llm: bool = False - ignore_chain: bool = False - ignore_agent: bool = False - @property def always_verbose(self) -> bool: """Whether to call verbose callbacks even if verbose is False.""" return False + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return False + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return False + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return False + @abstractmethod def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any @@ -31,7 +40,9 @@ class BaseCallbackHandler(BaseModel, ABC): """Run when LLM ends running.""" @abstractmethod - def on_llm_error(self, error: Exception, **kwargs: Any) -> None: + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when LLM errors.""" @abstractmethod @@ -45,7 +56,9 @@ class BaseCallbackHandler(BaseModel, ABC): """Run when chain ends running.""" @abstractmethod - def on_chain_error(self, error: Exception, **kwargs: Any) -> None: + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when chain errors.""" @abstractmethod @@ -59,7 +72,9 @@ class BaseCallbackHandler(BaseModel, ABC): """Run when tool ends running.""" @abstractmethod - def on_tool_error(self, error: Exception, **kwargs: Any) -> None: + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when tool errors.""" @abstractmethod @@ -82,15 +97,21 @@ class BaseCallbackManager(BaseCallbackHandler, ABC): def remove_handler(self, handler: BaseCallbackHandler) -> None: """Remove a handler from the callback manager.""" - @abstractmethod def set_handler(self, handler: BaseCallbackHandler) -> None: """Set handler as the only handler on the callback manager.""" + self.set_handlers([handler]) + + @abstractmethod + def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: + """Set handlers as the only handlers on the callback manager.""" class CallbackManager(BaseCallbackManager): """Callback manager that can be used to handle callbacks from langchain.""" - handlers: List[BaseCallbackHandler] + def __init__(self, handlers: List[BaseCallbackHandler]) -> None: + """Initialize callback manager.""" + self.handlers: List[BaseCallbackHandler] = handlers def on_llm_start( self, @@ -115,7 +136,10 @@ class CallbackManager(BaseCallbackManager): handler.on_llm_end(response) def on_llm_error( - self, error: Exception, verbose: bool = False, **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + verbose: bool = False, + **kwargs: Any ) -> None: """Run when LLM errors.""" for handler in self.handlers: @@ -146,7 +170,10 @@ class CallbackManager(BaseCallbackManager): handler.on_chain_end(outputs) def on_chain_error( - self, error: Exception, verbose: bool = False, **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + verbose: bool = False, + **kwargs: Any ) -> None: """Run when chain errors.""" for handler in self.handlers: @@ -175,7 +202,10 @@ class CallbackManager(BaseCallbackManager): handler.on_tool_end(output, **kwargs) def on_tool_error( - self, error: Exception, verbose: bool = False, **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + verbose: bool = False, + **kwargs: Any ) -> None: """Run when tool errors.""" for handler in self.handlers: @@ -206,6 +236,6 @@ class CallbackManager(BaseCallbackManager): """Remove a handler from the callback manager.""" self.handlers.remove(handler) - def set_handler(self, handler: BaseCallbackHandler) -> None: - """Set handler as the only handler on the callback manager.""" - self.handlers = [handler] + def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: + """Set handlers as the only handlers on the callback manager.""" + self.handlers = handlers diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py index b1471412..32b26591 100644 --- a/langchain/callbacks/openai_info.py +++ b/langchain/callbacks/openai_info.py @@ -1,5 +1,5 @@ """Callback Handler that prints to std out.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -29,7 +29,9 @@ class OpenAICallbackHandler(BaseCallbackHandler): if "total_tokens" in token_usage: self.total_tokens += token_usage["total_tokens"] - def on_llm_error(self, error: Exception, **kwargs: Any) -> None: + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass @@ -43,7 +45,9 @@ class OpenAICallbackHandler(BaseCallbackHandler): """Print out that we finished a chain.""" pass - def on_chain_error(self, error: Exception, **kwargs: Any) -> None: + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass @@ -68,7 +72,9 @@ class OpenAICallbackHandler(BaseCallbackHandler): """If not the final action, print out observation.""" pass - def on_tool_error(self, error: Exception, **kwargs: Any) -> None: + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index 3ec7a686..4b0772cd 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -1,7 +1,7 @@ """A shared CallbackManager.""" import threading -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from langchain.callbacks.base import ( BaseCallbackHandler, @@ -46,7 +46,9 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.on_llm_end(response, **kwargs) - def on_llm_error(self, error: Exception, **kwargs: Any) -> None: + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when LLM errors.""" with self._lock: self._callback_manager.on_llm_error(error, **kwargs) @@ -63,7 +65,9 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.on_chain_end(outputs, **kwargs) - def on_chain_error(self, error: Exception, **kwargs: Any) -> None: + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when chain errors.""" with self._lock: self._callback_manager.on_chain_error(error, **kwargs) @@ -80,7 +84,9 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.on_tool_end(output, **kwargs) - def on_tool_error(self, error: Exception, **kwargs: Any) -> None: + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when tool errors.""" with self._lock: self._callback_manager.on_tool_error(error, **kwargs) @@ -105,7 +111,7 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.remove_handler(callback) - def set_handler(self, handler: BaseCallbackHandler) -> None: - """Set handler as the only handler on the callback manager.""" + def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: + """Set handlers as the only handlers on the callback manager.""" with self._lock: - self._callback_manager.handlers = [handler] + self._callback_manager.handlers = handlers diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index ff8ea2f4..caf8e6b1 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -1,5 +1,5 @@ """Callback Handler that prints to std out.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text @@ -19,7 +19,9 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Do nothing.""" pass - def on_llm_error(self, error: Exception, **kwargs: Any) -> None: + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass @@ -34,7 +36,9 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Print out that we finished a chain.""" print("\n\033[1m> Finished chain.\033[0m") - def on_chain_error(self, error: Exception, **kwargs: Any) -> None: + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass @@ -61,7 +65,9 @@ class StdOutCallbackHandler(BaseCallbackHandler): print_text(output, color=color) print_text(f"\n{llm_prefix}") - def on_tool_error(self, error: Exception, **kwargs: Any) -> None: + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py index 99bc8d96..1451aac7 100644 --- a/langchain/callbacks/streamlit.py +++ b/langchain/callbacks/streamlit.py @@ -1,5 +1,5 @@ """Callback Handler that logs to streamlit.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import streamlit as st @@ -22,7 +22,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler): """Do nothing.""" pass - def on_llm_error(self, error: Exception, **kwargs: Any) -> None: + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass @@ -37,7 +39,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler): """Print out that we finished a chain.""" st.write("Finished chain.") - def on_chain_error(self, error: Exception, **kwargs: Any) -> None: + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass @@ -62,7 +66,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler): st.write(f"{observation_prefix}{output}") st.write(llm_prefix) - def on_tool_error(self, error: Exception, **kwargs: Any) -> None: + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Do nothing.""" pass diff --git a/langchain/callbacks/tracers/__init__.py b/langchain/callbacks/tracers/__init__.py new file mode 100644 index 00000000..8db5367f --- /dev/null +++ b/langchain/callbacks/tracers/__init__.py @@ -0,0 +1,12 @@ +"""Tracers that record execution of LangChain runs.""" + +from langchain.callbacks.tracers.base import SharedTracer, Tracer +from langchain.callbacks.tracers.langchain import BaseLangChainTracer + + +class SharedLangChainTracer(SharedTracer, BaseLangChainTracer): + """Shared tracer that records LangChain execution to LangChain endpoint.""" + + +class LangChainTracer(Tracer, BaseLangChainTracer): + """Tracer that records LangChain execution to LangChain endpoint.""" diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py new file mode 100644 index 00000000..76b11639 --- /dev/null +++ b/langchain/callbacks/tracers/base.py @@ -0,0 +1,334 @@ +"""Base interfaces for tracing runs.""" +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.shared import Singleton +from langchain.callbacks.tracers.schemas import ( + ChainRun, + LLMRun, + ToolRun, + TracerSession, + TracerSessionCreate, +) +from langchain.schema import AgentAction, AgentFinish, LLMResult + + +class TracerException(Exception): + """Base class for exceptions in tracers module.""" + + +class BaseTracer(BaseCallbackHandler, ABC): + """Base interface for tracers.""" + + @abstractmethod + def _add_child_run( + self, + parent_run: Union[ChainRun, ToolRun], + child_run: Union[LLMRun, ChainRun, ToolRun], + ) -> None: + """Add child run to a chain run or tool run.""" + + @abstractmethod + def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + + @abstractmethod + def _persist_session(self, session: TracerSessionCreate) -> TracerSession: + """Persist a tracing session.""" + + @abstractmethod + def _generate_id(self) -> Optional[Union[int, str]]: + """Generate an id for a run.""" + + def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession: + """NOT thread safe, do not call this method from multiple threads.""" + session_create = TracerSessionCreate(name=name, extra=kwargs) + session = self._persist_session(session_create) + self._session = session + return session + + @abstractmethod + def load_session(self, session_name: str) -> TracerSession: + """Load a tracing session and set it as the Tracer's session.""" + + @abstractmethod + def load_default_session(self) -> TracerSession: + """Load the default tracing session and set it as the Tracer's session.""" + + @property + @abstractmethod + def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: + """Get the tracer stack.""" + + @property + @abstractmethod + def _execution_order(self) -> int: + """Get the execution order for a run.""" + + @_execution_order.setter + @abstractmethod + def _execution_order(self, value: int) -> None: + """Set the execution order for a run.""" + + @property + @abstractmethod + def _session(self) -> Optional[TracerSession]: + """Get the tracing session.""" + + @_session.setter + @abstractmethod + def _session(self, value: TracerSession) -> None: + """Set the tracing session.""" + + def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + """Start a trace for a run.""" + self._execution_order += 1 + + if self._stack: + if not ( + isinstance(self._stack[-1], ChainRun) + or isinstance(self._stack[-1], ToolRun) + ): + raise TracerException( + f"Nested {run.__class__.__name__} can only be" + f" logged inside a ChainRun or ToolRun" + ) + self._add_child_run(self._stack[-1], run) + self._stack.append(run) + + def _end_trace(self) -> None: + """End a trace for a run.""" + run = self._stack.pop() + if not self._stack: + self._execution_order = 1 + self._persist_run(run) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Start a trace for an LLM run.""" + if self._session is None: + raise TracerException( + "Initialize a session with `new_session()` before starting a trace." + ) + + llm_run = LLMRun( + serialized=serialized, + prompts=prompts, + extra=kwargs, + start_time=datetime.utcnow(), + execution_order=self._execution_order, + session_id=self._session.id, + id=self._generate_id(), + ) + self._start_trace(llm_run) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """End a trace for an LLM run.""" + if not self._stack or not isinstance(self._stack[-1], LLMRun): + raise TracerException("No LLMRun found to be traced") + + self._stack[-1].end_time = datetime.utcnow() + self._stack[-1].response = response + + self._end_trace() + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Handle an error for an LLM run.""" + if not self._stack or not isinstance(self._stack[-1], LLMRun): + raise TracerException("No LLMRun found to be traced") + + self._stack[-1].error = repr(error) + self._stack[-1].end_time = datetime.utcnow() + + self._end_trace() + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Start a trace for a chain run.""" + if self._session is None: + raise TracerException( + "Initialize a session with `new_session()` before starting a trace." + ) + + chain_run = ChainRun( + serialized=serialized, + inputs=inputs, + extra=kwargs, + start_time=datetime.utcnow(), + execution_order=self._execution_order, + child_runs=[], + session_id=self._session.id, + id=self._generate_id(), + ) + self._start_trace(chain_run) + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """End a trace for a chain run.""" + if not self._stack or not isinstance(self._stack[-1], ChainRun): + raise TracerException("No ChainRun found to be traced") + + self._stack[-1].end_time = datetime.utcnow() + self._stack[-1].outputs = outputs + + self._end_trace() + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Handle an error for a chain run.""" + if not self._stack or not isinstance(self._stack[-1], ChainRun): + raise TracerException("No ChainRun found to be traced") + + self._stack[-1].end_time = datetime.utcnow() + self._stack[-1].error = repr(error) + + self._end_trace() + + def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Start a trace for a tool run.""" + if self._session is None: + raise TracerException( + "Initialize a session with `new_session()` before starting a trace." + ) + + tool_run = ToolRun( + serialized=serialized, + action=action.tool, + tool_input=action.tool_input, + extra=kwargs, + start_time=datetime.utcnow(), + execution_order=self._execution_order, + child_runs=[], + session_id=self._session.id, + id=self._generate_id(), + ) + self._start_trace(tool_run) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """End a trace for a tool run.""" + if not self._stack or not isinstance(self._stack[-1], ToolRun): + raise TracerException("No ToolRun found to be traced") + + self._stack[-1].end_time = datetime.utcnow() + self._stack[-1].output = output + + self._end_trace() + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Handle an error for a tool run.""" + if not self._stack or not isinstance(self._stack[-1], ToolRun): + raise TracerException("No ToolRun found to be traced") + + self._stack[-1].end_time = datetime.utcnow() + self._stack[-1].error = repr(error) + + self._end_trace() + + def on_text(self, text: str, **kwargs: Any) -> None: + """Handle a text message.""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Handle an agent finish message.""" + pass + + +class Tracer(BaseTracer, ABC): + """A non-thread safe implementation of the BaseTracer interface.""" + + def __init__(self) -> None: + """Initialize a tracer.""" + self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = [] + self._tracer_execution_order = 1 + self._tracer_session: Optional[TracerSession] = None + + @property + def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: + """Get the tracer stack.""" + return self._tracer_stack + + @property + def _execution_order(self) -> int: + """Get the execution order for a run.""" + return self._tracer_execution_order + + @_execution_order.setter + def _execution_order(self, value: int) -> None: + """Set the execution order for a run.""" + self._tracer_execution_order = value + + @property + def _session(self) -> Optional[TracerSession]: + """Get the tracing session.""" + return self._tracer_session + + @_session.setter + def _session(self, value: TracerSession) -> None: + """Set the tracing session.""" + if self._stack: + raise TracerException( + "Cannot set a session while a trace is being recorded" + ) + self._tracer_session = value + + +@dataclass +class TracerStack(threading.local): + """A stack of runs used for logging.""" + + stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list) + execution_order: int = 1 + + +class SharedTracer(Singleton, BaseTracer, ABC): + """A thread-safe Singleton implementation of BaseTracer.""" + + _tracer_stack = TracerStack() + _tracer_session = None + + @property + def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: + """Get the tracer stack.""" + return self._tracer_stack.stack + + @property + def _execution_order(self) -> int: + """Get the execution order for a run.""" + return self._tracer_stack.execution_order + + @_execution_order.setter + def _execution_order(self, value: int) -> None: + """Set the execution order for a run.""" + self._tracer_stack.execution_order = value + + @property + def _session(self) -> Optional[TracerSession]: + """Get the tracing session.""" + return self._tracer_session + + @_session.setter + def _session(self, value: TracerSession) -> None: + """Set the tracing session.""" + with self._lock: + # TODO: currently, we are only checking current thread's stack. + # Need to make sure that we are not in the middle of a trace + # in any thread. + if self._stack: + raise TracerException( + "Cannot set a session while a trace is being recorded" + ) + self._tracer_session = value diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py new file mode 100644 index 00000000..d2502204 --- /dev/null +++ b/langchain/callbacks/tracers/langchain.py @@ -0,0 +1,112 @@ +"""A Tracer implementation that records to LangChain endpoint.""" +from __future__ import annotations + +import logging +import os +from abc import ABC +from typing import Any, Dict, Optional, Union + +import requests + +from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.schemas import ( + ChainRun, + LLMRun, + ToolRun, + TracerSession, + TracerSessionCreate, +) + + +class BaseLangChainTracer(BaseTracer, ABC): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + always_verbose: bool = True + _endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + _headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + _headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + + def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, LLMRun): + endpoint = f"{self._endpoint}/llm-runs" + elif isinstance(run, ChainRun): + endpoint = f"{self._endpoint}/chain-runs" + else: + endpoint = f"{self._endpoint}/tool-runs" + + try: + requests.post( + endpoint, + data=run.json(), + headers=self._headers, + ) + except Exception as e: + logging.warning(f"Failed to persist run: {e}") + + def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession: + """Persist a session.""" + try: + r = requests.post( + f"{self._endpoint}/sessions", + data=session_create.json(), + headers=self._headers, + ) + session = TracerSession(id=r.json()["id"], **session_create.dict()) + except Exception as e: + logging.warning(f"Failed to create session, using default session: {e}") + session = TracerSession(id=1, **session_create.dict()) + return session + + def load_session(self, session_name: str) -> TracerSession: + """Load a session from the tracer.""" + try: + r = requests.get( + f"{self._endpoint}/sessions?name={session_name}", + headers=self._headers, + ) + tracer_session = TracerSession(**r.json()[0]) + self._session = tracer_session + return tracer_session + except Exception as e: + logging.warning( + f"Failed to load session {session_name}, using empty session: {e}" + ) + tracer_session = TracerSession(id=1) + self._session = tracer_session + return tracer_session + + def load_default_session(self) -> TracerSession: + """Load the default tracing session and set it as the Tracer's session.""" + try: + r = requests.get( + f"{self._endpoint}/sessions", + headers=self._headers, + ) + # Use the first session result + tracer_session = TracerSession(**r.json()[0]) + self._session = tracer_session + return tracer_session + except Exception as e: + logging.warning(f"Failed to default session, using empty session: {e}") + tracer_session = TracerSession(id=1) + self._session = tracer_session + return tracer_session + + def _add_child_run( + self, + parent_run: Union[ChainRun, ToolRun], + child_run: Union[LLMRun, ChainRun, ToolRun], + ) -> None: + """Add child run to a chain run or tool run.""" + if isinstance(child_run, LLMRun): + parent_run.child_llm_runs.append(child_run) + elif isinstance(child_run, ChainRun): + parent_run.child_chain_runs.append(child_run) + else: + parent_run.child_tool_runs.append(child_run) + + def _generate_id(self) -> Optional[Union[int, str]]: + """Generate an id for a run.""" + return None diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py new file mode 100644 index 00000000..bb77d747 --- /dev/null +++ b/langchain/callbacks/tracers/schemas.py @@ -0,0 +1,76 @@ +"""Schemas for tracers.""" +from __future__ import annotations + +import datetime +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from langchain.schema import LLMResult + + +class TracerSessionBase(BaseModel): + """Base class for TracerSession.""" + + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + name: Optional[str] = None + extra: Optional[Dict[str, Any]] = None + + +class TracerSessionCreate(TracerSessionBase): + """Create class for TracerSession.""" + + pass + + +class TracerSession(TracerSessionBase): + """TracerSession schema.""" + + id: int + + +class BaseRun(BaseModel): + """Base class for Run.""" + + id: Optional[Union[int, str]] = None + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + extra: Optional[Dict[str, Any]] = None + execution_order: int + serialized: Dict[str, Any] + session_id: int + error: Optional[str] = None + + +class LLMRun(BaseRun): + """Class for LLMRun.""" + + prompts: List[str] + response: Optional[LLMResult] = None + + +class ChainRun(BaseRun): + """Class for ChainRun.""" + + inputs: Dict[str, Any] + outputs: Optional[Dict[str, Any]] = None + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list) + + +class ToolRun(BaseRun): + """Class for ToolRun.""" + + tool_input: str + output: Optional[str] = None + action: str + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list) + + +ChainRun.update_forward_refs() +ToolRun.update_forward_refs() diff --git a/langchain/chains/base.py b/langchain/chains/base.py index e30cdb15..f2191859 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -150,7 +150,7 @@ class Chain(BaseModel, ABC): ) try: outputs = self._call(inputs) - except Exception as e: + except (KeyboardInterrupt, Exception) as e: self.callback_manager.on_chain_error(e, verbose=self.verbose) raise e self.callback_manager.on_chain_end(outputs, verbose=self.verbose) diff --git a/langchain/docker-compose.yaml b/langchain/docker-compose.yaml new file mode 100644 index 00000000..d1558cdb --- /dev/null +++ b/langchain/docker-compose.yaml @@ -0,0 +1,29 @@ +version: '3' +services: + langchain-frontend: + image: notlangchain/langchainplus-frontend:latest + ports: + - 4173:4173 + environment: + - BACKEND_URL=http://langchain-backend:8000 + - PUBLIC_BASE_URL=http://localhost:8000 + - PUBLIC_DEV_MODE=true + depends_on: + - langchain-backend + langchain-backend: + image: notlangchain/langchainplus:latest + environment: + - PORT=8000 + - LANGCHAIN_ENV=local + ports: + - 8000:8000 + depends_on: + - langchain-db + langchain-db: + image: postgres:14.1 + environment: + - POSTGRES_PASSWORD=postgres + - POSTGRES_USER=postgres + - POSTGRES_DB=postgres + ports: + - 5432:5432 diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 59eccc34..b9ddf50f 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -74,7 +74,7 @@ class BaseLLM(BaseModel, ABC): ) try: output = self._generate(prompts, stop=stop) - except Exception as e: + except (KeyboardInterrupt, Exception) as e: self.callback_manager.on_llm_error(e, verbose=self.verbose) raise e self.callback_manager.on_llm_end(output, verbose=self.verbose) @@ -97,7 +97,7 @@ class BaseLLM(BaseModel, ABC): ) try: new_results = self._generate(missing_prompts, stop=stop) - except Exception as e: + except (KeyboardInterrupt, Exception) as e: self.callback_manager.on_llm_error(e, verbose=self.verbose) raise e self.callback_manager.on_llm_end(new_results, verbose=self.verbose) diff --git a/langchain/schema.py b/langchain/schema.py index 6bb53eb5..9938f08e 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -1,7 +1,10 @@ """Common schema objects.""" +from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional +from dataclasses_json import dataclass_json + class AgentAction(NamedTuple): """Agent's action to take.""" @@ -18,7 +21,9 @@ class AgentFinish(NamedTuple): log: str -class Generation(NamedTuple): +@dataclass_json +@dataclass +class Generation: """Output of a single generation.""" text: str @@ -30,7 +35,9 @@ class Generation(NamedTuple): # TODO: add log probs -class LLMResult(NamedTuple): +@dataclass_json +@dataclass +class LLMResult: """Class that contains all relevant information for an LLM Result.""" generations: List[List[Generation]] diff --git a/langchain/server.py b/langchain/server.py new file mode 100644 index 00000000..4b00a478 --- /dev/null +++ b/langchain/server.py @@ -0,0 +1,14 @@ +"""Script to run langchain-server locally using docker-compose.""" +import subprocess +from pathlib import Path + + +def main() -> None: + """Run the langchain server locally.""" + p = Path(__file__).absolute().parent / "docker-compose.yaml" + subprocess.run(["docker-compose", "-f", str(p), "pull"]) + subprocess.run(["docker-compose", "-f", str(p), "up"]) + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index d69b36d7..1e3ffe4e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -784,6 +784,26 @@ files = [ {file = "cymem-2.0.7.tar.gz", hash = "sha256:e6034badb5dd4e10344211c81f16505a55553a7164adc314c75bd80cf07e57a8"}, ] +[[package]] +name = "dataclasses-json" +version = "0.5.7" +description = "Easily serialize dataclasses to and from JSON" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "dataclasses-json-0.5.7.tar.gz", hash = "sha256:c2c11bc8214fbf709ffc369d11446ff6945254a7f09128154a7620613d8fda90"}, + {file = "dataclasses_json-0.5.7-py3-none-any.whl", hash = "sha256:bc285b5f892094c3a53d558858a88553dd6a61a11ab1a8128a0e554385dcc5dd"}, +] + +[package.dependencies] +marshmallow = ">=3.3.0,<4.0.0" +marshmallow-enum = ">=1.5.1,<2.0.0" +typing-inspect = ">=0.4.0" + +[package.extras] +dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest (>=6.2.3)", "simplejson", "types-dataclasses"] + [[package]] name = "debugpy" version = "1.6.5" @@ -1152,6 +1172,21 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "freezegun" +version = "1.2.2" +description = "Let your Python tests travel through time" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "freezegun-1.2.2-py3-none-any.whl", hash = "sha256:ea1b963b993cb9ea195adbd893a48d573fda951b0da64f60883d7e988b606c9f"}, + {file = "freezegun-1.2.2.tar.gz", hash = "sha256:cd22d1ba06941384410cd967d8a99d5ae2442f57dfafeff2fda5de8dc5c05446"}, +] + +[package.dependencies] +python-dateutil = ">=2.7" + [[package]] name = "google-api-core" version = "2.11.0" @@ -2235,7 +2270,6 @@ files = [ {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"}, {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"}, {file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"}, - {file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"}, {file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"}, @@ -2245,7 +2279,6 @@ files = [ {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"}, {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"}, {file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"}, - {file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"}, {file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"}, @@ -2404,6 +2437,42 @@ files = [ {file = "MarkupSafe-2.1.2.tar.gz", hash = "sha256:abcabc8c2b26036d62d4c746381a6f7cf60aafcc653198ad678306986b09450d"}, ] +[[package]] +name = "marshmallow" +version = "3.19.0" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "marshmallow-3.19.0-py3-none-any.whl", hash = "sha256:93f0958568da045b0021ec6aeb7ac37c81bfcccbb9a0e7ed8559885070b3a19b"}, + {file = "marshmallow-3.19.0.tar.gz", hash = "sha256:90032c0fd650ce94b6ec6dc8dfeb0e3ff50c144586462c389b81a07205bedb78"}, +] + +[package.dependencies] +packaging = ">=17.0" + +[package.extras] +dev = ["flake8 (==5.0.4)", "flake8-bugbear (==22.10.25)", "mypy (==0.990)", "pre-commit (>=2.4,<3.0)", "pytest", "pytz", "simplejson", "tox"] +docs = ["alabaster (==0.7.12)", "autodocsumm (==0.2.9)", "sphinx (==5.3.0)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"] +lint = ["flake8 (==5.0.4)", "flake8-bugbear (==22.10.25)", "mypy (==0.990)", "pre-commit (>=2.4,<3.0)"] +tests = ["pytest", "pytz", "simplejson"] + +[[package]] +name = "marshmallow-enum" +version = "1.5.1" +description = "Enum field for Marshmallow" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "marshmallow-enum-1.5.1.tar.gz", hash = "sha256:38e697e11f45a8e64b4a1e664000897c659b60aa57bfa18d44e226a9920b6e58"}, + {file = "marshmallow_enum-1.5.1-py2.py3-none-any.whl", hash = "sha256:57161ab3dbfde4f57adeb12090f39592e992b9c86d206d02f6bd03ebec60f072"}, +] + +[package.dependencies] +marshmallow = ">=2.0.0" + [[package]] name = "matplotlib-inline" version = "0.1.6" @@ -2580,7 +2649,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." -category = "dev" +category = "main" optional = false python-versions = "*" files = [ @@ -4941,14 +5010,11 @@ files = [ {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"}, {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"}, {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"}, - {file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"}, - {file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, - {file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, @@ -4957,7 +5023,6 @@ files = [ {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"}, {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"}, {file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"}, - {file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, @@ -5284,6 +5349,22 @@ files = [ {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] +[[package]] +name = "typing-inspect" +version = "0.8.0" +description = "Runtime inspection utilities for typing module." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.8.0-py3-none-any.whl", hash = "sha256:5fbf9c1e65d4fa01e701fe12a5bca6c6e08a4ffd5bc60bfac028253a447c5188"}, + {file = "typing_inspect-0.8.0.tar.gz", hash = "sha256:8b1ff0c400943b6145df8119c41c244ca8207f1f10c9c057aeed1560e4806e3d"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "uri-template" version = "1.2.0" @@ -5585,4 +5666,4 @@ llms = ["manifest-ml", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "2da2c13a9a1572e4f7792c336c57f7087d79c21a1370d6f0534bcfa80abcad30" +content-hash = "537ede877b299a8800eb26a428607be97b59f89435259abda2c7cc86092306c6" diff --git a/pyproject.toml b/pyproject.toml index 97553d55..e0531f32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,9 @@ license = "MIT" readme = "README.md" repository = "https://www.github.com/hwchase17/langchain" +[tool.poetry.scripts] +langchain-server = "langchain.server:main" + [tool.poetry.dependencies] python = ">=3.8.1,<4.0" pydantic = "^1" @@ -31,6 +34,7 @@ weaviate-client = {version = "^3", optional = true} google-api-python-client = {version = "2.70.0", optional = true} wolframalpha = {version = "5.0.0", optional = true} qdrant-client = {version = "^0.11.7", optional = true} +dataclasses-json = "^0.5.7" [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" @@ -52,6 +56,7 @@ pytest-cov = "^4.0.0" pytest-dotenv = "^0.5.2" duckdb-engine = "^0.6.6" pytest-watcher = "^0.2.6" +freezegun = "^1.2.2" [tool.poetry.group.lint.dependencies] flake8-docstrings = "^1.6.0" @@ -59,7 +64,7 @@ black = "^22.10.0" isort = "^5.10.1" flake8 = "^6.0.0" types-toml = "^0.10.8.1" -types-redis = "^4.3.21.6" +types-redis = "^4.3.21.6" [tool.poetry.group.typing.dependencies] mypy = "^0.991" diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 896b05aa..36f32b85 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -1,17 +1,43 @@ """A fake callback handler for testing purposes.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Union + +from pydantic import BaseModel from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult -class FakeCallbackHandler(BaseCallbackHandler): +class FakeCallbackHandler(BaseModel, BaseCallbackHandler): """Fake callback handler for testing.""" starts: int = 0 ends: int = 0 errors: int = 0 text: int = 0 + ignore_llm_: bool = False + ignore_chain_: bool = False + ignore_agent_: bool = False + always_verbose_: bool = False + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return self.always_verbose_ + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any @@ -23,7 +49,9 @@ class FakeCallbackHandler(BaseCallbackHandler): """Run when LLM ends running.""" self.ends += 1 - def on_llm_error(self, error: Exception, **kwargs: Any) -> None: + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when LLM errors.""" self.errors += 1 @@ -37,7 +65,9 @@ class FakeCallbackHandler(BaseCallbackHandler): """Run when chain ends running.""" self.ends += 1 - def on_chain_error(self, error: Exception, **kwargs: Any) -> None: + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when chain errors.""" self.errors += 1 @@ -51,7 +81,9 @@ class FakeCallbackHandler(BaseCallbackHandler): """Run when tool ends running.""" self.ends += 1 - def on_tool_error(self, error: Exception, **kwargs: Any) -> None: + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: """Run when tool errors.""" self.errors += 1 diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 03d0181b..acca2171 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -8,6 +8,31 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler def _test_callback_manager( manager: BaseCallbackManager, *handlers: FakeCallbackHandler +) -> None: + """Test the CallbackManager.""" + manager.on_llm_start({}, []) + manager.on_llm_end(LLMResult(generations=[])) + manager.on_llm_error(Exception()) + manager.on_chain_start({"name": "foo"}, {}) + manager.on_chain_end({}) + manager.on_chain_error(Exception()) + manager.on_tool_start({}, AgentAction("", "", "")) + manager.on_tool_end("") + manager.on_tool_error(Exception()) + manager.on_agent_finish(AgentFinish(log="", return_values={})) + for handler in handlers: + if handler.always_verbose: + assert handler.starts == 3 + assert handler.ends == 4 + assert handler.errors == 3 + else: + assert handler.starts == 0 + assert handler.ends == 0 + assert handler.errors == 0 + + +def _test_callback_manager_pass_in_verbose( + manager: BaseCallbackManager, *handlers: FakeCallbackHandler ) -> None: """Test the CallbackManager.""" manager.on_llm_start({}, [], verbose=True) @@ -19,7 +44,7 @@ def _test_callback_manager( manager.on_tool_start({}, AgentAction("", "", ""), verbose=True) manager.on_tool_end("", verbose=True) manager.on_tool_error(Exception(), verbose=True) - manager.on_agent_finish(AgentFinish({}, ""), verbose=True) + manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True) for handler in handlers: assert handler.starts == 3 assert handler.ends == 4 @@ -27,17 +52,25 @@ def _test_callback_manager( def test_callback_manager() -> None: + """Test the CallbackManager.""" + handler1 = FakeCallbackHandler(always_verbose_=True) + handler2 = FakeCallbackHandler(always_verbose_=False) + manager = CallbackManager([handler1, handler2]) + _test_callback_manager(manager, handler1, handler2) + + +def test_callback_manager_pass_in_verbose() -> None: """Test the CallbackManager.""" handler1 = FakeCallbackHandler() handler2 = FakeCallbackHandler() - manager = CallbackManager(handlers=[handler1, handler2]) - _test_callback_manager(manager, handler1, handler2) + manager = CallbackManager([handler1, handler2]) + _test_callback_manager_pass_in_verbose(manager, handler1, handler2) def test_ignore_llm() -> None: """Test ignore llm param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_llm=True) - handler2 = FakeCallbackHandler() + handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True) + handler2 = FakeCallbackHandler(always_verbose_=True) manager = CallbackManager(handlers=[handler1, handler2]) manager.on_llm_start({}, [], verbose=True) manager.on_llm_end(LLMResult(generations=[]), verbose=True) @@ -52,8 +85,8 @@ def test_ignore_llm() -> None: def test_ignore_chain() -> None: """Test ignore chain param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_chain=True) - handler2 = FakeCallbackHandler() + handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True) + handler2 = FakeCallbackHandler(always_verbose_=True) manager = CallbackManager(handlers=[handler1, handler2]) manager.on_chain_start({"name": "foo"}, {}, verbose=True) manager.on_chain_end({}, verbose=True) @@ -68,8 +101,8 @@ def test_ignore_chain() -> None: def test_ignore_agent() -> None: """Test ignore agent param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_agent=True) - handler2 = FakeCallbackHandler() + handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True) + handler2 = FakeCallbackHandler(always_verbose_=True) manager = CallbackManager(handlers=[handler1, handler2]) manager.on_tool_start({}, AgentAction("", "", ""), verbose=True) manager.on_tool_end("", verbose=True) @@ -90,7 +123,7 @@ def test_shared_callback_manager() -> None: assert manager1 is manager2 - handler1 = FakeCallbackHandler() + handler1 = FakeCallbackHandler(always_verbose_=True) handler2 = FakeCallbackHandler() manager1.add_handler(handler1) manager2.add_handler(handler2) diff --git a/tests/unit_tests/callbacks/tracers/__init__.py b/tests/unit_tests/callbacks/tracers/__init__.py new file mode 100644 index 00000000..bb6b0428 --- /dev/null +++ b/tests/unit_tests/callbacks/tracers/__init__.py @@ -0,0 +1 @@ +"""Tests for correct functioning of tracers.""" diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py new file mode 100644 index 00000000..625d4e8f --- /dev/null +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -0,0 +1,530 @@ +"""Test Tracer classes.""" +from __future__ import annotations + +import threading +from datetime import datetime +from typing import List, Optional, Union + +import pytest +from freezegun import freeze_time + +from langchain.callbacks.tracers.base import ( + BaseTracer, + ChainRun, + LLMRun, + SharedTracer, + ToolRun, + Tracer, + TracerException, + TracerSession, +) +from langchain.callbacks.tracers.schemas import TracerSessionCreate +from langchain.schema import AgentAction, LLMResult + +TEST_SESSION_ID = 2023 + + +@freeze_time("2023-01-01") +def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: + return ChainRun( + id=None, + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + child_runs=[ + ToolRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + serialized={}, + tool_input="test", + output="test", + action="action", + session_id=TEST_SESSION_ID, + error=None, + child_runs=[ + LLMRun( + id=None, + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + serialized={}, + prompts=[], + response=LLMResult([[]]), + session_id=TEST_SESSION_ID, + ) + ], + ), + LLMRun( + id=None, + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + serialized={}, + prompts=[], + response=LLMResult([[]]), + session_id=TEST_SESSION_ID, + ), + ], + ) + + +def _perform_nested_run(tracer: BaseTracer) -> None: + """Perform a nested run.""" + tracer.on_chain_start(serialized={}, inputs={}) + tracer.on_tool_start( + serialized={}, action=AgentAction(tool="action", tool_input="test", log="") + ) + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_end(response=LLMResult([[]])) + tracer.on_tool_end("test") + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_end(response=LLMResult([[]])) + tracer.on_chain_end(outputs={}) + + +def _add_child_run( + parent_run: Union[ChainRun, ToolRun], + child_run: Union[LLMRun, ChainRun, ToolRun], +) -> None: + """Add child run to a chain run or tool run.""" + parent_run.child_runs.append(child_run) + + +def _generate_id() -> Optional[Union[int, str]]: + """Generate an id for a run.""" + return None + + +def load_session(session_name: str) -> TracerSession: + """Load a tracing session.""" + return TracerSession(id=1, name=session_name, start_time=datetime.utcnow()) + + +def _persist_session(session: TracerSessionCreate) -> TracerSession: + """Persist a tracing session.""" + return TracerSession(id=TEST_SESSION_ID, **session.dict()) + + +def load_default_session() -> TracerSession: + """Load a tracing session.""" + return TracerSession(id=1, name="default", start_time=datetime.utcnow()) + + +class FakeTracer(Tracer): + """Fake tracer that records LangChain execution.""" + + def __init__(self) -> None: + """Initialize the tracer.""" + super().__init__() + self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] + + def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + self.runs.append(run) + + def _add_child_run( + self, + parent_run: Union[ChainRun, ToolRun], + child_run: Union[LLMRun, ChainRun, ToolRun], + ) -> None: + """Add child run to a chain run or tool run.""" + _add_child_run(parent_run, child_run) + + def _generate_id(self) -> Optional[Union[int, str]]: + """Generate an id for a run.""" + return _generate_id() + + def _persist_session(self, session: TracerSessionCreate) -> TracerSession: + """Persist a tracing session.""" + return _persist_session(session) + + def load_session(self, session_name: str) -> TracerSession: + """Load a tracing session.""" + return load_session(session_name) + + def load_default_session(self) -> TracerSession: + """Load a tracing session.""" + return load_default_session() + + +class FakeSharedTracer(SharedTracer): + """Fake shared tracer that records LangChain execution.""" + + runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] + + def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + with self._lock: + self.runs.append(run) + + def remove_runs(self) -> None: + """Remove all runs.""" + with self._lock: + self.runs = [] + + def _add_child_run( + self, + parent_run: Union[ChainRun, ToolRun], + child_run: Union[LLMRun, ChainRun, ToolRun], + ) -> None: + """Add child run to a chain run or tool run.""" + _add_child_run(parent_run, child_run) + + def _generate_id(self) -> Optional[Union[int, str]]: + """Generate an id for a run.""" + return _generate_id() + + def _persist_session(self, session: TracerSessionCreate) -> TracerSession: + """Persist a tracing session.""" + return _persist_session(session) + + def load_session(self, session_name: str) -> TracerSession: + """Load a tracing session.""" + return load_session(session_name) + + def load_default_session(self) -> TracerSession: + """Load a tracing session.""" + return load_default_session() + + +@freeze_time("2023-01-01") +def test_tracer_llm_run() -> None: + """Test tracer on an LLM run.""" + compare_run = LLMRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + prompts=[], + response=LLMResult([[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_end(response=LLMResult([[]])) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_errors_no_session() -> None: + """Test tracer on an LLM run without a session.""" + tracer = FakeTracer() + + with pytest.raises(TracerException): + tracer.on_llm_start(serialized={}, prompts=[]) + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_errors_no_start() -> None: + """Test tracer on an LLM run without a start.""" + tracer = FakeTracer() + + tracer.new_session() + with pytest.raises(TracerException): + tracer.on_llm_end(response=LLMResult([[]])) + + +@freeze_time("2023-01-01") +def test_tracer_multiple_llm_runs() -> None: + """Test the tracer with multiple runs.""" + compare_run = LLMRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + prompts=[], + response=LLMResult([[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + num_runs = 10 + for _ in range(num_runs): + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_end(response=LLMResult([[]])) + + assert tracer.runs == [compare_run] * num_runs + + +@freeze_time("2023-01-01") +def test_tracer_chain_run() -> None: + """Test tracer on a Chain run.""" + compare_run = ChainRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_chain_start(serialized={}, inputs={}) + tracer.on_chain_end(outputs={}) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run() -> None: + """Test tracer on a Tool run.""" + compare_run = ToolRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + tool_input="test", + output="test", + action="action", + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_tool_start( + serialized={}, action=AgentAction(tool="action", tool_input="test", log="") + ) + tracer.on_tool_end("test") + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_run() -> None: + """Test tracer on a nested run.""" + tracer = FakeTracer() + tracer.new_session() + _perform_nested_run(tracer) + assert tracer.runs == [_get_compare_run()] + + +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error() -> None: + """Test tracer on an LLM run with an error.""" + exception = Exception("test") + + compare_run = LLMRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + prompts=[], + response=None, + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_error(exception) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_chain_run_on_error() -> None: + """Test tracer on a Chain run with an error.""" + exception = Exception("test") + + compare_run = ChainRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + inputs={}, + outputs=None, + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_chain_start(serialized={}, inputs={}) + tracer.on_chain_error(exception) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_tool_run_on_error() -> None: + """Test tracer on a Tool run with an error.""" + exception = Exception("test") + + compare_run = ToolRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + tool_input="test", + output=None, + action="action", + session_id=TEST_SESSION_ID, + error=repr(exception), + ) + tracer = FakeTracer() + + tracer.new_session() + tracer.on_tool_start( + serialized={}, action=AgentAction(tool="action", tool_input="test", log="") + ) + tracer.on_tool_error(exception) + assert tracer.runs == [compare_run] + + +@freeze_time("2023-01-01") +def test_tracer_nested_runs_on_error() -> None: + """Test tracer on a nested run with an error.""" + exception = Exception("test") + + tracer = FakeTracer() + tracer.new_session() + + for _ in range(3): + tracer.on_chain_start(serialized={}, inputs={}) + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_end(response=LLMResult([[]])) + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_end(response=LLMResult([[]])) + tracer.on_tool_start( + serialized={}, action=AgentAction(tool="action", tool_input="test", log="") + ) + tracer.on_llm_start(serialized={}, prompts=[]) + tracer.on_llm_error(exception) + tracer.on_tool_error(exception) + tracer.on_chain_error(exception) + + compare_run = ChainRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + serialized={}, + session_id=TEST_SESSION_ID, + error=repr(exception), + inputs={}, + outputs=None, + child_runs=[ + LLMRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + serialized={}, + session_id=TEST_SESSION_ID, + error=None, + prompts=[], + response=LLMResult(generations=[[]], llm_output=None), + ), + LLMRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + serialized={}, + session_id=TEST_SESSION_ID, + error=None, + prompts=[], + response=LLMResult(generations=[[]], llm_output=None), + ), + ToolRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + serialized={}, + session_id=TEST_SESSION_ID, + error=repr(exception), + tool_input="test", + output=None, + action="action", + child_runs=[ + LLMRun( + id=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=5, + serialized={}, + session_id=TEST_SESSION_ID, + error=repr(exception), + prompts=[], + response=None, + ) + ], + child_llm_runs=[], + child_chain_runs=[], + child_tool_runs=[], + ), + ], + child_llm_runs=[], + child_chain_runs=[], + child_tool_runs=[], + ) + + assert tracer.runs == [compare_run] * 3 + + +@freeze_time("2023-01-01") +def test_shared_tracer_nested_run() -> None: + """Test shared tracer on a nested run.""" + tracer = FakeSharedTracer() + tracer.new_session() + tracer.remove_runs() + _perform_nested_run(tracer) + assert tracer.runs == [_get_compare_run()] + + +@freeze_time("2023-01-01") +def test_shared_tracer_nested_run_multithreaded() -> None: + """Test shared tracer on a nested run.""" + tracer = FakeSharedTracer() + tracer.remove_runs() + tracer.new_session() + threads = [] + num_threads = 10 + for _ in range(num_threads): + thread = threading.Thread(target=_perform_nested_run, args=(tracer,)) + thread.start() + threads.append(thread) + + for thread in threads: + thread.join() + + assert tracer.runs == [_get_compare_run()] * num_threads