forked from Archives/langchain
Add BaseCallbackHandler and CallbackManager (#476)
parent
0f1df0dc2c
commit
46b31626b5
@ -0,0 +1,8 @@
|
||||
"""Callback handlers that allow listening to events in LangChain."""
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
|
||||
|
||||
def get_callback_manager() -> BaseCallbackManager:
|
||||
"""Return the shared callback manager."""
|
||||
return SharedCallbackManager()
|
@ -0,0 +1,137 @@
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class BaseCallbackHandler(ABC):
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@abstractmethod
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
|
||||
@abstractmethod
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Initialize the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_llm_start(serialized, prompts, **extra)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_llm_end(response)
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
for handler in self.handlers:
|
||||
handler.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_chain_start(serialized, inputs, **extra)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_chain_end(outputs)
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
for handler in self.handlers:
|
||||
handler.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_tool_start(serialized, action, tool_input, **extra)
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
"""Run when tool ends running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_tool_end(output)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
for handler in self.handlers:
|
||||
handler.on_tool_error(error)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
@ -0,0 +1,101 @@
|
||||
"""A shared CallbackManager."""
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class Singleton:
|
||||
"""A thread-safe singleton class that can be inherited from."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> Any:
|
||||
"""Create a new shared instance of the class."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
# Another thread could have created the instance
|
||||
# before we acquired the lock. So check that the
|
||||
# instance is still nonexistent.
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
"""A thread-safe singleton CallbackManager."""
|
||||
|
||||
_callback_manager: CallbackManager = CallbackManager([])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_start(serialized, prompts, **extra)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_end(response)
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_start(serialized, inputs, **extra)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_end(outputs)
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_start(
|
||||
serialized, action, tool_input, **extra
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
"""Run when tool ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_end(output)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_error(error)
|
||||
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a callback to the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.add_handler(callback)
|
||||
|
||||
def remove_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Remove a callback from the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.remove_handler(callback)
|
@ -0,0 +1 @@
|
||||
"""Tests for correct functioning of callbacks."""
|
@ -0,0 +1,107 @@
|
||||
"""Test CallbackManager."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock callback handler."""
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, "", "")
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
for handler in handlers:
|
||||
assert handler.starts == 3
|
||||
assert handler.ends == 3
|
||||
assert handler.errors == 3
|
||||
|
||||
|
||||
def test_callback_manager() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_shared_callback_manager() -> None:
|
||||
"""Test the SharedCallbackManager."""
|
||||
manager1 = SharedCallbackManager()
|
||||
manager2 = SharedCallbackManager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager1.add_handler(handler1)
|
||||
manager2.add_handler(handler2)
|
||||
_test_callback_manager(manager1, handler1, handler2)
|
Loading…
Reference in New Issue