From 46b31626b5ee5fc40f5098404db3a15400471fce Mon Sep 17 00:00:00 2001
From: Ankush Gola <9536492+agola11@users.noreply.github.com>
Date: Thu, 29 Dec 2022 15:11:37 -0500
Subject: [PATCH] Add BaseCallbackHandler and CallbackManager (#476)

---
 langchain/agents/agent.py                       |   5 +
 langchain/callbacks/__init__.py                 |   8 +
 langchain/callbacks/base.py                     | 137 ++++++++++++++++++
 langchain/callbacks/shared.py                   | 101 +++++++++++++
 langchain/chains/base.py                        |   3 +
 langchain/chains/llm.py                         |   3 +-
 langchain/llms/base.py                          |  26 ++--
 langchain/llms/openai.py                        |   4 +-
 langchain/schema.py                             |  12 +-
 tests/unit_tests/callbacks/__init__.py          |   1 +
 .../callbacks/test_callback_manager.py          | 107 ++++++++++++++
 tests/unit_tests/test_hyde.py                   |   4 +-
 12 files changed, 392 insertions(+), 19 deletions(-)
 create mode 100644 langchain/callbacks/__init__.py
 create mode 100644 langchain/callbacks/base.py
 create mode 100644 langchain/callbacks/shared.py
 create mode 100644 tests/unit_tests/callbacks/__init__.py
 create mode 100644 tests/unit_tests/callbacks/test_callback_manager.py

diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index 9385e7ca..2c95b2ed 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel, root_validator
 
 import langchain
 from langchain.agents.tools import Tool
+from langchain.callbacks import get_callback_manager
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
@@ -213,8 +214,12 @@ class AgentExecutor(Chain, BaseModel):
             # And then we lookup the tool
             if output.tool in name_to_tool_map:
                 chain = name_to_tool_map[output.tool]
+                get_callback_manager().on_tool_start(
+                    {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
+                )
                 # We then call the tool on the tool input to get an observation
                 observation = chain(output.tool_input)
+                get_callback_manager().on_tool_end(observation)
                 color = color_mapping[output.tool]
             else:
                 observation = f"{output.tool} is not a valid tool, try another one."
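The executor change above brackets every tool invocation with on_tool_start/on_tool_end on the shared manager (note that the serialized dict truncates the tool's repr to 60 characters). A minimal sketch of the same instrumentation pattern applied to an arbitrary callable; the run_tool helper and its arguments are illustrative, not part of the patch:

```python
from typing import Callable

from langchain.callbacks import get_callback_manager


def run_tool(name: str, tool_fn: Callable[[str], str], tool_input: str) -> str:
    """Illustrative helper mirroring AgentExecutor's new instrumentation."""
    manager = get_callback_manager()
    # Fire the start event before the tool runs ...
    manager.on_tool_start({"name": name}, name, tool_input)
    observation = tool_fn(tool_input)
    # ... and the end event with the observation the tool produced.
    manager.on_tool_end(observation)
    return observation
```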
diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py
new file mode 100644
index 00000000..eef25d7e
--- /dev/null
+++ b/langchain/callbacks/__init__.py
@@ -0,0 +1,8 @@
+"""Callback handlers that allow listening to events in LangChain."""
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.shared import SharedCallbackManager
+
+
+def get_callback_manager() -> BaseCallbackManager:
+    """Return the shared callback manager."""
+    return SharedCallbackManager()
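Because SharedCallbackManager (added in shared.py below) is a singleton, get_callback_manager hands out one process-wide manager. A quick sketch of the implication:

```python
from langchain.callbacks import get_callback_manager

# Every call returns the same SharedCallbackManager instance, so a handler
# registered once is seen by all chains, LLMs, and tools in the process.
assert get_callback_manager() is get_callback_manager()
```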
diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py
new file mode 100644
index 00000000..7ce41f3a
--- /dev/null
+++ b/langchain/callbacks/base.py
@@ -0,0 +1,137 @@
+"""Base callback handler that can be used to handle callbacks from langchain."""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+from langchain.schema import LLMResult
+
+
+class BaseCallbackHandler(ABC):
+    """Base callback handler that can be used to handle callbacks from langchain."""
+
+    @abstractmethod
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+    ) -> None:
+        """Run when LLM starts running."""
+
+    @abstractmethod
+    def on_llm_end(
+        self,
+        response: LLMResult,
+    ) -> None:
+        """Run when LLM ends running."""
+
+    @abstractmethod
+    def on_llm_error(self, error: Exception) -> None:
+        """Run when LLM errors."""
+
+    @abstractmethod
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+    ) -> None:
+        """Run when chain starts running."""
+
+    @abstractmethod
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Run when chain ends running."""
+
+    @abstractmethod
+    def on_chain_error(self, error: Exception) -> None:
+        """Run when chain errors."""
+
+    @abstractmethod
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+    ) -> None:
+        """Run when tool starts running."""
+
+    @abstractmethod
+    def on_tool_end(self, output: str) -> None:
+        """Run when tool ends running."""
+
+    @abstractmethod
+    def on_tool_error(self, error: Exception) -> None:
+        """Run when tool errors."""
+
+
+class BaseCallbackManager(BaseCallbackHandler, ABC):
+    """Base callback manager that can be used to handle callbacks from LangChain."""
+
+    @abstractmethod
+    def add_handler(self, callback: BaseCallbackHandler) -> None:
+        """Add a handler to the callback manager."""
+
+    @abstractmethod
+    def remove_handler(self, handler: BaseCallbackHandler) -> None:
+        """Remove a handler from the callback manager."""
+
+
+class CallbackManager(BaseCallbackManager):
+    """Callback manager that can be used to handle callbacks from langchain."""
+
+    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
+        """Initialize the callback manager."""
+        self.handlers = handlers
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+    ) -> None:
+        """Run when LLM starts running."""
+        for handler in self.handlers:
+            handler.on_llm_start(serialized, prompts, **extra)
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+    ) -> None:
+        """Run when LLM ends running."""
+        for handler in self.handlers:
+            handler.on_llm_end(response)
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Run when LLM errors."""
+        for handler in self.handlers:
+            handler.on_llm_error(error)
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+    ) -> None:
+        """Run when chain starts running."""
+        for handler in self.handlers:
+            handler.on_chain_start(serialized, inputs, **extra)
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Run when chain ends running."""
+        for handler in self.handlers:
+            handler.on_chain_end(outputs)
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Run when chain errors."""
+        for handler in self.handlers:
+            handler.on_chain_error(error)
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+    ) -> None:
+        """Run when tool starts running."""
+        for handler in self.handlers:
+            handler.on_tool_start(serialized, action, tool_input, **extra)
+
+    def on_tool_end(self, output: str) -> None:
+        """Run when tool ends running."""
+        for handler in self.handlers:
+            handler.on_tool_end(output)
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Run when tool errors."""
+        for handler in self.handlers:
+            handler.on_tool_error(error)
+
+    def add_handler(self, handler: BaseCallbackHandler) -> None:
+        """Add a handler to the callback manager."""
+        self.handlers.append(handler)
+
+    def remove_handler(self, handler: BaseCallbackHandler) -> None:
+        """Remove a handler from the callback manager."""
+        self.handlers.remove(handler)
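All nine event methods on BaseCallbackHandler are abstract, so a concrete handler must implement each one before it can be instantiated, and CallbackManager simply fans each event out to its handlers in registration order. A minimal sketch of a concrete handler (the PrintCallbackHandler name and its print-based behavior are illustrative, not part of the patch):

```python
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler, CallbackManager
from langchain.schema import LLMResult


class PrintCallbackHandler(BaseCallbackHandler):
    """Illustrative handler: prints every event it receives."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
    ) -> None:
        print(f"LLM start: {len(prompts)} prompt(s)")

    def on_llm_end(self, response: LLMResult) -> None:
        print("LLM end")

    def on_llm_error(self, error: Exception) -> None:
        print(f"LLM error: {error}")

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
    ) -> None:
        print(f"Chain start: {serialized.get('name')}")

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        print("Chain end")

    def on_chain_error(self, error: Exception) -> None:
        print(f"Chain error: {error}")

    def on_tool_start(
        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
    ) -> None:
        print(f"Tool start: {action}({tool_input!r})")

    def on_tool_end(self, output: str) -> None:
        print(f"Tool end: {output!r}")

    def on_tool_error(self, error: Exception) -> None:
        print(f"Tool error: {error}")


# The manager forwards each event to every registered handler in order.
manager = CallbackManager([PrintCallbackHandler()])
manager.on_tool_start({"name": "search"}, "search", "weather in SF")
manager.on_tool_end("Sunny, 65F")
```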
diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py
new file mode 100644
index 00000000..b285a394
--- /dev/null
+++ b/langchain/callbacks/shared.py
@@ -0,0 +1,101 @@
+"""A shared CallbackManager."""
+
+import threading
+from typing import Any, Dict, List
+
+from langchain.callbacks.base import (
+    BaseCallbackHandler,
+    BaseCallbackManager,
+    CallbackManager,
+)
+from langchain.schema import LLMResult
+
+
+class Singleton:
+    """A thread-safe singleton class that can be inherited from."""
+
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls) -> Any:
+        """Create a new shared instance of the class."""
+        if cls._instance is None:
+            with cls._lock:
+                # Another thread could have created the instance
+                # before we acquired the lock. So check that the
+                # instance is still nonexistent.
+                if not cls._instance:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+
+class SharedCallbackManager(Singleton, BaseCallbackManager):
+    """A thread-safe singleton CallbackManager."""
+
+    _callback_manager: CallbackManager = CallbackManager([])
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+    ) -> None:
+        """Run when LLM starts running."""
+        with self._lock:
+            self._callback_manager.on_llm_start(serialized, prompts, **extra)
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+    ) -> None:
+        """Run when LLM ends running."""
+        with self._lock:
+            self._callback_manager.on_llm_end(response)
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Run when LLM errors."""
+        with self._lock:
+            self._callback_manager.on_llm_error(error)
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+    ) -> None:
+        """Run when chain starts running."""
+        with self._lock:
+            self._callback_manager.on_chain_start(serialized, inputs, **extra)
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Run when chain ends running."""
+        with self._lock:
+            self._callback_manager.on_chain_end(outputs)
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Run when chain errors."""
+        with self._lock:
+            self._callback_manager.on_chain_error(error)
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+    ) -> None:
+        """Run when tool starts running."""
+        with self._lock:
+            self._callback_manager.on_tool_start(
+                serialized, action, tool_input, **extra
+            )
+
+    def on_tool_end(self, output: str) -> None:
+        """Run when tool ends running."""
+        with self._lock:
+            self._callback_manager.on_tool_end(output)
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Run when tool errors."""
+        with self._lock:
+            self._callback_manager.on_tool_error(error)
+
+    def add_handler(self, callback: BaseCallbackHandler) -> None:
+        """Add a callback to the callback manager."""
+        with self._lock:
+            self._callback_manager.add_handler(callback)
+
+    def remove_handler(self, callback: BaseCallbackHandler) -> None:
+        """Remove a callback from the callback manager."""
+        with self._lock:
+            self._callback_manager.remove_handler(callback)
diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index 24f90ea0..c4ce68d6 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union
 from pydantic import BaseModel, Extra, Field
 
 import langchain
+from langchain.callbacks import get_callback_manager
 
 
 class Memory(BaseModel, ABC):
@@ -109,7 +110,9 @@ class Chain(BaseModel, ABC):
             print(
                 f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
            )
+        get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs)
         outputs = self._call(inputs)
+        get_callback_manager().on_chain_end(outputs)
         if self.verbose:
             print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
         self._validate_outputs(outputs)
diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py
index cbe01bd9..72c1f3a0 100644
--- a/langchain/chains/llm.py
+++ b/langchain/chains/llm.py
@@ -5,8 +5,9 @@ from pydantic import BaseModel, Extra
 
 import langchain
 from langchain.chains.base import Chain
-from langchain.llms.base import BaseLLM, LLMResult
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import LLMResult
 
 
 class LLMChain(Chain, BaseModel):
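Singleton.__new__ uses double-checked locking: the unlocked first check keeps the common path cheap, while the re-check under the lock closes the race where two threads pass the first check before either has assigned _instance. A small demonstration sketch (thread count arbitrary):

```python
import threading

from langchain.callbacks.shared import SharedCallbackManager

instances = []


def grab() -> None:
    # Each thread "constructs" a SharedCallbackManager concurrently.
    instances.append(SharedCallbackManager())


threads = [threading.Thread(target=grab) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Double-checked locking guarantees every thread saw the same object.
assert all(instance is instances[0] for instance in instances)
```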
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 1941385a..1a1ca392 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -2,23 +2,14 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import yaml
 from pydantic import BaseModel, Extra
 
 import langchain
-from langchain.schema import Generation
-
-
-class LLMResult(NamedTuple):
-    """Class that contains all relevant information for an LLM Result."""
-
-    generations: List[List[Generation]]
-    """List of the things generated. This is List[List[]] because
-    each input could have multiple generations."""
-    llm_output: Optional[dict] = None
-    """For arbitrary LLM provider specific output."""
+from langchain.callbacks import get_callback_manager
+from langchain.schema import Generation, LLMResult
 
 
 class BaseLLM(BaseModel, ABC):
@@ -48,7 +39,12 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            return self._generate(prompts, stop=stop)
+            get_callback_manager().on_llm_start(
+                {"name": self.__class__.__name__}, prompts
+            )
+            output = self._generate(prompts, stop=stop)
+            get_callback_manager().on_llm_end(output)
+            return output
         params = self._llm_dict()
         params["stop"] = stop
         llm_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -62,7 +58,11 @@ class BaseLLM(BaseModel, ABC):
             else:
                 missing_prompts.append(prompt)
                 missing_prompt_idxs.append(i)
+        get_callback_manager().on_llm_start(
+            {"name": self.__class__.__name__}, missing_prompts
+        )
         new_results = self._generate(missing_prompts, stop=stop)
+        get_callback_manager().on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
             existing_prompts[i] = result
             prompt = prompts[i]
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 42ca39b5..2995935e 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -4,8 +4,8 @@
 from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import BaseLLM, LLMResult
-from langchain.schema import Generation
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
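On the cached path, generate() now fires on_llm_start only for the prompts that missed the cache, so handlers observe actual LLM work rather than cache hits. The surrounding context lines also show how the cache key is built; a quick illustration with hypothetical parameter values:

```python
# The key is the stringified, sorted parameter list, which is deterministic
# regardless of dict insertion order (parameter values here are illustrative).
params = {"model_name": "text-davinci-003", "temperature": 0.0, "stop": None}
llm_string = str(sorted([(k, v) for k, v in params.items()]))
# -> "[('model_name', 'text-davinci-003'), ('stop', None), ('temperature', 0.0)]"
```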
diff --git a/langchain/schema.py b/langchain/schema.py
index 31cf1cdc..a4b4e626 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -1,6 +1,6 @@
 """Common schema objects."""
 
-from typing import NamedTuple
+from typing import List, NamedTuple, Optional
 
 
 class AgentAction(NamedTuple):
@@ -24,3 +24,13 @@ class Generation(NamedTuple):
     text: str
     """Generated text output."""
     # TODO: add log probs
+
+
+class LLMResult(NamedTuple):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
diff --git a/tests/unit_tests/callbacks/__init__.py b/tests/unit_tests/callbacks/__init__.py
new file mode 100644
index 00000000..cd34752b
--- /dev/null
+++ b/tests/unit_tests/callbacks/__init__.py
@@ -0,0 +1 @@
+"""Tests for correct functioning of callbacks."""
diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py
new file mode 100644
index 00000000..fcab8dda
--- /dev/null
+++ b/tests/unit_tests/callbacks/test_callback_manager.py
@@ -0,0 +1,107 @@
+"""Test CallbackManager."""
+
+from typing import Any, Dict, List
+
+from langchain.callbacks.base import (
+    BaseCallbackHandler,
+    BaseCallbackManager,
+    CallbackManager,
+)
+from langchain.callbacks.shared import SharedCallbackManager
+from langchain.schema import LLMResult
+
+
+class FakeCallbackHandler(BaseCallbackHandler):
+    """Fake callback handler for testing."""
+
+    def __init__(self) -> None:
+        """Initialize the mock callback handler."""
+        self.starts = 0
+        self.ends = 0
+        self.errors = 0
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+    ) -> None:
+        """Run when LLM starts running."""
+        self.starts += 1
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+    ) -> None:
+        """Run when LLM ends running."""
+        self.ends += 1
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Run when LLM errors."""
+        self.errors += 1
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+    ) -> None:
+        """Run when chain starts running."""
+        self.starts += 1
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Run when chain ends running."""
+        self.ends += 1
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Run when chain errors."""
+        self.errors += 1
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+    ) -> None:
+        """Run when tool starts running."""
+        self.starts += 1
+
+    def on_tool_end(self, output: str) -> None:
+        """Run when tool ends running."""
+        self.ends += 1
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Run when tool errors."""
+        self.errors += 1
+
+
+def _test_callback_manager(
+    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
+) -> None:
+    """Test the CallbackManager."""
+    manager.on_llm_start({}, [])
+    manager.on_llm_end(LLMResult(generations=[]))
+    manager.on_llm_error(Exception())
+    manager.on_chain_start({}, {})
+    manager.on_chain_end({})
+    manager.on_chain_error(Exception())
+    manager.on_tool_start({}, "", "")
+    manager.on_tool_end("")
+    manager.on_tool_error(Exception())
+    for handler in handlers:
+        assert handler.starts == 3
+        assert handler.ends == 3
+        assert handler.errors == 3
+
+
+def test_callback_manager() -> None:
+    """Test the CallbackManager."""
+    handler1 = FakeCallbackHandler()
+    handler2 = FakeCallbackHandler()
+    manager = CallbackManager([handler1, handler2])
+    _test_callback_manager(manager, handler1, handler2)
+
+
+def test_shared_callback_manager() -> None:
+    """Test the SharedCallbackManager."""
+    manager1 = SharedCallbackManager()
+    manager2 = SharedCallbackManager()
+
+    assert manager1 is manager2
+
+    handler1 = FakeCallbackHandler()
+    handler2 = FakeCallbackHandler()
+    manager1.add_handler(handler1)
+    manager2.add_handler(handler2)
+    _test_callback_manager(manager1, handler1, handler2)
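The schema change above relocates LLMResult from langchain.llms.base to langchain.schema, presumably so that langchain.callbacks.base can import it without creating an import cycle through langchain.llms.base (which now imports langchain.callbacks). Its nested shape is worth spelling out; a small constructed example (all values illustrative):

```python
from langchain.schema import Generation, LLMResult

# One inner list per input prompt; a single prompt can yield several
# generations (for example, when a provider is asked for n > 1 completions).
result = LLMResult(
    generations=[
        [Generation(text="Paris")],                           # prompt 0
        [Generation(text="blue"), Generation(text="azure")],  # prompt 1
    ],
    llm_output={"model_name": "text-davinci-003"},  # provider-specific extras
)
assert len(result.generations) == 2
assert len(result.generations[1]) == 2
```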
diff --git a/tests/unit_tests/test_hyde.py b/tests/unit_tests/test_hyde.py
index 91df0f34..91b7bb55 100644
--- a/tests/unit_tests/test_hyde.py
+++ b/tests/unit_tests/test_hyde.py
@@ -7,8 +7,8 @@ from pydantic import BaseModel
 
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
 from langchain.embeddings.hyde.prompts import PROMPT_MAP
-from langchain.llms.base import BaseLLM, LLMResult
-from langchain.schema import Generation
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
 
 
 class FakeEmbeddings(Embeddings):