Add BaseCallbackHandler and CallbackManager (#476)

harrison/callback-updates
Ankush Gola authored 1 year ago, committed by GitHub
parent 0f1df0dc2c
commit 46b31626b5

@@ -9,6 +9,7 @@ from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.tools import Tool
from langchain.callbacks import get_callback_manager
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
@@ -213,8 +214,12 @@ class AgentExecutor(Chain, BaseModel):
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
get_callback_manager().on_tool_start(
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
)
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
get_callback_manager().on_tool_end(observation)
color = color_mapping[output.tool]
else:
observation = f"{output.tool} is not a valid tool, try another one."

@@ -0,0 +1,8 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager
def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()
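
A usage sketch, not part of this diff: get_callback_manager() always hands back the same shared singleton, so a handler registered once is seen by every chain, LLM, and tool in the process.

# Sketch under that assumption; nothing here is added by this commit.
from langchain.callbacks import get_callback_manager

assert get_callback_manager() is get_callback_manager()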

@@ -0,0 +1,137 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain.schema import LLMResult
class BaseCallbackHandler(ABC):
"""Base callback handler that can be used to handle callbacks from langchain."""
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
@abstractmethod
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
"""Base callback manager that can be used to handle callbacks from LangChain."""
@abstractmethod
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
@abstractmethod
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize the callback manager."""
self.handlers = handlers
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
handler.on_llm_start(serialized, prompts, **extra)
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
handler.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
handler.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
handler.on_chain_start(serialized, inputs, **extra)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
handler.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
for handler in self.handlers:
handler.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
handler.on_tool_start(serialized, action, tool_input, **extra)
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
handler.on_tool_end(output)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
for handler in self.handlers:
handler.on_tool_error(error)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
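
A usage sketch, not part of this diff: a hypothetical CountingHandler implements every abstract hook, and CallbackManager fans each event out to all registered handlers (the test file added below exercises the same behaviour).

# Sketch only; CountingHandler is a hypothetical handler, not added by this commit.
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler, CallbackManager
from langchain.schema import LLMResult

class CountingHandler(BaseCallbackHandler):
    """Count every callback received, regardless of event type."""

    def __init__(self) -> None:
        self.events = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **extra: str) -> None:
        self.events += 1

    def on_llm_end(self, response: LLMResult) -> None:
        self.events += 1

    def on_llm_error(self, error: Exception) -> None:
        self.events += 1

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str) -> None:
        self.events += 1

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        self.events += 1

    def on_chain_error(self, error: Exception) -> None:
        self.events += 1

    def on_tool_start(self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str) -> None:
        self.events += 1

    def on_tool_end(self, output: str) -> None:
        self.events += 1

    def on_tool_error(self, error: Exception) -> None:
        self.events += 1

handler = CountingHandler()
manager = CallbackManager([handler])
manager.on_tool_start({"name": "Search"}, "Search", "weather in SF")  # illustrative values
manager.on_tool_end("Sunny, 65F")
assert handler.events == 2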

@@ -0,0 +1,101 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import LLMResult
class Singleton:
"""A thread-safe singleton class that can be inherited from."""
_instance = None
_lock = threading.Lock()
def __new__(cls) -> Any:
"""Create a new shared instance of the class."""
if cls._instance is None:
with cls._lock:
# Another thread could have created the instance
# before we acquired the lock. So check that the
# instance is still nonexistent.
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
class SharedCallbackManager(Singleton, BaseCallbackManager):
"""A thread-safe singleton CallbackManager."""
_callback_manager: CallbackManager = CallbackManager([])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start(serialized, prompts, **extra)
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
with self._lock:
self._callback_manager.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
with self._lock:
self._callback_manager.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
with self._lock:
self._callback_manager.on_chain_start(serialized, inputs, **extra)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
with self._lock:
self._callback_manager.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
with self._lock:
self._callback_manager.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(
serialized, action, tool_input, **extra
)
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
with self._lock:
self._callback_manager.on_tool_end(output)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
with self._lock:
self._callback_manager.on_tool_error(error)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:
self._callback_manager.add_handler(callback)
def remove_handler(self, callback: BaseCallbackHandler) -> None:
"""Remove a callback from the callback manager."""
with self._lock:
self._callback_manager.remove_handler(callback)
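
A sketch, not part of this diff: constructing SharedCallbackManager twice yields the same object, and get_callback_manager() returns that same instance, so a handler added anywhere receives events fired anywhere. CountingHandler is the hypothetical handler from the sketch above.

# Sketch under those assumptions; CountingHandler is hypothetical (see above).
from langchain.callbacks import get_callback_manager
from langchain.callbacks.shared import SharedCallbackManager

a = SharedCallbackManager()
b = SharedCallbackManager()
assert a is b and get_callback_manager() is a  # one process-wide singleton

handler = CountingHandler()
a.add_handler(handler)
get_callback_manager().on_chain_end({"answer": "42"})  # fired "elsewhere"
assert handler.events == 1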

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field
import langchain
from langchain.callbacks import get_callback_manager
class Memory(BaseModel, ABC):
@@ -109,7 +110,9 @@ class Chain(BaseModel, ABC):
print(
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
)
get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs)
outputs = self._call(inputs)
get_callback_manager().on_chain_end(outputs)
if self.verbose:
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
self._validate_outputs(outputs)

@@ -5,8 +5,9 @@ from pydantic import BaseModel, Extra
import langchain
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM, LLMResult
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import LLMResult
class LLMChain(Chain, BaseModel):

@@ -2,23 +2,14 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union
import yaml
from pydantic import BaseModel, Extra
import langchain
from langchain.schema import Generation
class LLMResult(NamedTuple):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
from langchain.callbacks import get_callback_manager
from langchain.schema import Generation, LLMResult
class BaseLLM(BaseModel, ABC):
@@ -48,7 +39,12 @@ class BaseLLM(BaseModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
return self._generate(prompts, stop=stop)
get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, prompts
)
output = self._generate(prompts, stop=stop)
get_callback_manager().on_llm_end(output)
return output
params = self._llm_dict()
params["stop"] = stop
llm_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -62,7 +58,11 @@ class BaseLLM(BaseModel, ABC):
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, missing_prompts
)
new_results = self._generate(missing_prompts, stop=stop)
get_callback_manager().on_llm_end(new_results)
for i, result in enumerate(new_results.generations):
existing_prompts[i] = result
prompt = prompts[i]

@@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

@@ -1,6 +1,6 @@
"""Common schema objects."""
from typing import NamedTuple
from typing import List, NamedTuple, Optional
class AgentAction(NamedTuple):
@@ -24,3 +24,13 @@ class Generation(NamedTuple):
text: str
"""Generated text output."""
# TODO: add log probs
class LLMResult(NamedTuple):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""

@@ -0,0 +1 @@
"""Tests for correct functioning of callbacks."""

@@ -0,0 +1,107 @@
"""Test CallbackManager."""
from typing import Any, Dict, List
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
def __init__(self) -> None:
"""Initialize the mock callback handler."""
self.starts = 0
self.ends = 0
self.errors = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
manager.on_llm_end(LLMResult(generations=[]))
manager.on_llm_error(Exception())
manager.on_chain_start({}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, "", "")
manager.on_tool_end("")
manager.on_tool_error(Exception())
for handler in handlers:
assert handler.starts == 3
assert handler.ends == 3
assert handler.errors == 3
def test_callback_manager() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
def test_shared_callback_manager() -> None:
"""Test the SharedCallbackManager."""
manager1 = SharedCallbackManager()
manager2 = SharedCallbackManager()
assert manager1 is manager2
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)

@@ -7,8 +7,8 @@ from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
class FakeEmbeddings(Embeddings):
