Add BaseCallbackHandler and CallbackManager (#476)

commit 46b31626b5
parent 0f1df0dc2c
Ankush Gola, 2022-12-29 15:11:37 -05:00, committed by GitHub
12 changed files with 392 additions and 19 deletions
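
The commit introduces an event bus for LangChain: chains, LLMs, and tools now report start/end/error events to a process-wide callback manager, and consumers subscribe by registering handlers. A rough usage sketch built only from the APIs added in this diff; the PrintingHandler class and its print bodies are illustrative, not part of the commit:

from typing import Any, Dict, List

from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult


class PrintingHandler(BaseCallbackHandler):
    """Illustrative handler: print every event LangChain reports."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
    ) -> None:
        print(f"LLM start: {serialized['name']} on {len(prompts)} prompt(s)")

    def on_llm_end(self, response: LLMResult) -> None:
        print("LLM end")

    def on_llm_error(self, error: Exception) -> None:
        print(f"LLM error: {error}")

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
    ) -> None:
        print(f"Chain start: {serialized['name']}")

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        print("Chain end")

    def on_chain_error(self, error: Exception) -> None:
        print(f"Chain error: {error}")

    def on_tool_start(
        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
    ) -> None:
        print(f"Tool start: {action}({tool_input!r})")

    def on_tool_end(self, output: str) -> None:
        print("Tool end")

    def on_tool_error(self, error: Exception) -> None:
        print(f"Tool error: {error}")


# Register once; every Chain, BaseLLM, and AgentExecutor tool call in this
# process now reports events through the shared manager.
get_callback_manager().add_handler(PrintingHandler())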

langchain/agents/agent.py

@@ -9,6 +9,7 @@ from pydantic import BaseModel, root_validator
 import langchain
 from langchain.agents.tools import Tool
+from langchain.callbacks import get_callback_manager
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
@@ -213,8 +214,12 @@ class AgentExecutor(Chain, BaseModel):
             # And then we lookup the tool
             if output.tool in name_to_tool_map:
                 chain = name_to_tool_map[output.tool]
+                get_callback_manager().on_tool_start(
+                    {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
+                )
                 # We then call the tool on the tool input to get an observation
                 observation = chain(output.tool_input)
+                get_callback_manager().on_tool_end(observation)
                 color = color_mapping[output.tool]
             else:
                 observation = f"{output.tool} is not a valid tool, try another one."

langchain/callbacks/__init__.py (new file)

@@ -0,0 +1,8 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager


def get_callback_manager() -> BaseCallbackManager:
    """Return the shared callback manager."""
    return SharedCallbackManager()
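
A small sketch of the contract this module exposes: get_callback_manager always hands back the same singleton, so a handler registered once is seen by every chain, LLM, and tool in the process.

from langchain.callbacks import get_callback_manager

manager = get_callback_manager()
# SharedCallbackManager is a singleton (see langchain/callbacks/shared.py
# below), so repeated calls return the same object.
assert manager is get_callback_manager()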

langchain/callbacks/base.py (new file, 137 lines)

@@ -0,0 +1,137 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from langchain.schema import LLMResult


class BaseCallbackHandler(ABC):
    """Base callback handler that can be used to handle callbacks from langchain."""

    @abstractmethod
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
    ) -> None:
        """Run when LLM starts running."""

    @abstractmethod
    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""

    @abstractmethod
    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""

    @abstractmethod
    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
    ) -> None:
        """Run when chain starts running."""

    @abstractmethod
    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""

    @abstractmethod
    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""

    @abstractmethod
    def on_tool_start(
        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
    ) -> None:
        """Run when tool starts running."""

    @abstractmethod
    def on_tool_end(self, output: str) -> None:
        """Run when tool ends running."""

    @abstractmethod
    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""


class BaseCallbackManager(BaseCallbackHandler, ABC):
    """Base callback manager that can be used to handle callbacks from LangChain."""

    @abstractmethod
    def add_handler(self, callback: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""

    @abstractmethod
    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""


class CallbackManager(BaseCallbackManager):
    """Callback manager that can be used to handle callbacks from langchain."""

    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
        """Initialize the callback manager."""
        self.handlers = handlers

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
    ) -> None:
        """Run when LLM starts running."""
        for handler in self.handlers:
            handler.on_llm_start(serialized, prompts, **extra)

    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""
        for handler in self.handlers:
            handler.on_llm_end(response)

    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""
        for handler in self.handlers:
            handler.on_llm_error(error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
    ) -> None:
        """Run when chain starts running."""
        for handler in self.handlers:
            handler.on_chain_start(serialized, inputs, **extra)

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""
        for handler in self.handlers:
            handler.on_chain_end(outputs)

    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""
        for handler in self.handlers:
            handler.on_chain_error(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
    ) -> None:
        """Run when tool starts running."""
        for handler in self.handlers:
            handler.on_tool_start(serialized, action, tool_input, **extra)

    def on_tool_end(self, output: str) -> None:
        """Run when tool ends running."""
        for handler in self.handlers:
            handler.on_tool_end(output)

    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""
        for handler in self.handlers:
            handler.on_tool_error(error)

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)
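
CallbackManager is a composite: it implements the handler interface itself and forwards every event to each registered handler in order. A short illustrative sketch of the fan-out, reusing the hypothetical PrintingHandler from the overview above:

from langchain.callbacks.base import CallbackManager
from langchain.schema import LLMResult

handler_a = PrintingHandler()  # illustrative handler from the sketch above
handler_b = PrintingHandler()

manager = CallbackManager([handler_a])
manager.add_handler(handler_b)

# Each event fans out to every registered handler, in registration order.
manager.on_llm_start({"name": "llm"}, ["prompt"])
manager.on_llm_end(LLMResult(generations=[]))

manager.remove_handler(handler_a)  # handler_b keeps receiving events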

langchain/callbacks/shared.py (new file)

@@ -0,0 +1,101 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List

from langchain.callbacks.base import (
    BaseCallbackHandler,
    BaseCallbackManager,
    CallbackManager,
)
from langchain.schema import LLMResult


class Singleton:
    """A thread-safe singleton class that can be inherited from."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls) -> Any:
        """Create a new shared instance of the class."""
        if cls._instance is None:
            with cls._lock:
                # Another thread could have created the instance
                # before we acquired the lock. So check that the
                # instance is still nonexistent.
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance


class SharedCallbackManager(Singleton, BaseCallbackManager):
    """A thread-safe singleton CallbackManager."""

    _callback_manager: CallbackManager = CallbackManager([])

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
    ) -> None:
        """Run when LLM starts running."""
        with self._lock:
            self._callback_manager.on_llm_start(serialized, prompts, **extra)

    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""
        with self._lock:
            self._callback_manager.on_llm_end(response)

    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""
        with self._lock:
            self._callback_manager.on_llm_error(error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
    ) -> None:
        """Run when chain starts running."""
        with self._lock:
            self._callback_manager.on_chain_start(serialized, inputs, **extra)

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""
        with self._lock:
            self._callback_manager.on_chain_end(outputs)

    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""
        with self._lock:
            self._callback_manager.on_chain_error(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
    ) -> None:
        """Run when tool starts running."""
        with self._lock:
            self._callback_manager.on_tool_start(
                serialized, action, tool_input, **extra
            )

    def on_tool_end(self, output: str) -> None:
        """Run when tool ends running."""
        with self._lock:
            self._callback_manager.on_tool_end(output)

    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""
        with self._lock:
            self._callback_manager.on_tool_error(error)

    def add_handler(self, callback: BaseCallbackHandler) -> None:
        """Add a callback to the callback manager."""
        with self._lock:
            self._callback_manager.add_handler(callback)

    def remove_handler(self, callback: BaseCallbackHandler) -> None:
        """Remove a callback from the callback manager."""
        with self._lock:
            self._callback_manager.remove_handler(callback)
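
Singleton.__new__ uses double-checked locking: the fast path skips the lock once the instance exists, and the second check inside the lock prevents two racing threads from each creating an instance. A minimal sketch of the resulting invariant (the grab helper is illustrative):

import threading

from langchain.callbacks.shared import SharedCallbackManager

instances = []

def grab() -> None:
    instances.append(SharedCallbackManager())

threads = [threading.Thread(target=grab) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every thread observed the same shared instance.
assert all(obj is instances[0] for obj in instances)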

langchain/chains/base.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union
 from pydantic import BaseModel, Extra, Field
 
 import langchain
+from langchain.callbacks import get_callback_manager
 
 
 class Memory(BaseModel, ABC):
@@ -109,7 +110,9 @@ class Chain(BaseModel, ABC):
             print(
                 f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
             )
+        get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs)
         outputs = self._call(inputs)
+        get_callback_manager().on_chain_end(outputs)
        if self.verbose:
            print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
        self._validate_outputs(outputs)
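
Chain.__call__ now brackets _call with on_chain_start and on_chain_end, so every chain run is reported without any work by the chain author. A sketch of a chain whose runs would be observed, assuming the Chain interface of this era (abstract input_keys, output_keys, and _call); EchoChain is illustrative:

from typing import Dict, List

from pydantic import BaseModel

from langchain.chains.base import Chain


class EchoChain(Chain, BaseModel):
    """Illustrative chain that returns its input unchanged."""

    @property
    def input_keys(self) -> List[str]:
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        return {"text": inputs["text"]}


# __call__ fires on_chain_start before _call and on_chain_end after it.
EchoChain()({"text": "hi"})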

langchain/chains/llm.py

@@ -5,8 +5,9 @@ from pydantic import BaseModel, Extra
 
 import langchain
 from langchain.chains.base import Chain
-from langchain.llms.base import BaseLLM, LLMResult
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import LLMResult
 
 
 class LLMChain(Chain, BaseModel):

langchain/llms/base.py

@@ -2,23 +2,14 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import yaml
 from pydantic import BaseModel, Extra
 
 import langchain
-from langchain.schema import Generation
-
-
-class LLMResult(NamedTuple):
-    """Class that contains all relevant information for an LLM Result."""
-
-    generations: List[List[Generation]]
-    """List of the things generated. This is List[List[]] because
-    each input could have multiple generations."""
-    llm_output: Optional[dict] = None
-    """For arbitrary LLM provider specific output."""
+from langchain.callbacks import get_callback_manager
+from langchain.schema import Generation, LLMResult
 
 
 class BaseLLM(BaseModel, ABC):
@@ -48,7 +39,12 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            return self._generate(prompts, stop=stop)
+            get_callback_manager().on_llm_start(
+                {"name": self.__class__.__name__}, prompts
+            )
+            output = self._generate(prompts, stop=stop)
+            get_callback_manager().on_llm_end(output)
+            return output
         params = self._llm_dict()
         params["stop"] = stop
         llm_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -62,7 +58,11 @@ class BaseLLM(BaseModel, ABC):
             else:
                 missing_prompts.append(prompt)
                 missing_prompt_idxs.append(i)
+        get_callback_manager().on_llm_start(
+            {"name": self.__class__.__name__}, missing_prompts
+        )
         new_results = self._generate(missing_prompts, stop=stop)
+        get_callback_manager().on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
             existing_prompts[i] = result
             prompt = prompts[i]
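
Both the cached and uncached paths of BaseLLM.generate now emit on_llm_start before _generate and on_llm_end after it, so handlers also see cache misses. A sketch of an LLM whose calls would be observed, assuming subclasses of this era implement _generate and _llm_type; FakeLLM is illustrative:

from typing import List, Optional

from pydantic import BaseModel

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult


class FakeLLM(BaseLLM, BaseModel):
    """Illustrative LLM that returns a canned string."""

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return LLMResult(generations=[[Generation(text="fake")] for _ in prompts])

    @property
    def _llm_type(self) -> str:
        return "fake"


# generate() now calls on_llm_start/on_llm_end around _generate, so any
# handler registered on the shared manager sees this call.
FakeLLM().generate(["hello"])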

(modified file, path not shown)

@@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import BaseLLM, LLMResult
-from langchain.schema import Generation
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env

langchain/schema.py

@@ -1,6 +1,6 @@
 """Common schema objects."""
-from typing import NamedTuple
+from typing import List, NamedTuple, Optional
 
 
 class AgentAction(NamedTuple):
@@ -24,3 +24,13 @@ class Generation(NamedTuple):
     text: str
     """Generated text output."""
     # TODO: add log probs
+
+
+class LLMResult(NamedTuple):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""

(new test file, path not shown)

@@ -0,0 +1 @@
"""Tests for correct functioning of callbacks."""

(new test file, path not shown)

@@ -0,0 +1,107 @@
"""Test CallbackManager."""
from typing import Any, Dict, List

from langchain.callbacks.base import (
    BaseCallbackHandler,
    BaseCallbackManager,
    CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import LLMResult


class FakeCallbackHandler(BaseCallbackHandler):
    """Fake callback handler for testing."""

    def __init__(self) -> None:
        """Initialize the mock callback handler."""
        self.starts = 0
        self.ends = 0
        self.errors = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
    ) -> None:
        """Run when LLM starts running."""
        self.starts += 1

    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""
        self.ends += 1

    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
    ) -> None:
        """Run when chain starts running."""
        self.starts += 1

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""
        self.ends += 1

    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
    ) -> None:
        """Run when tool starts running."""
        self.starts += 1

    def on_tool_end(self, output: str) -> None:
        """Run when tool ends running."""
        self.ends += 1

    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""
        self.errors += 1


def _test_callback_manager(
    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
    """Test the CallbackManager."""
    manager.on_llm_start({}, [])
    manager.on_llm_end(LLMResult(generations=[]))
    manager.on_llm_error(Exception())
    manager.on_chain_start({}, {})
    manager.on_chain_end({})
    manager.on_chain_error(Exception())
    manager.on_tool_start({}, "", "")
    manager.on_tool_end("")
    manager.on_tool_error(Exception())
    for handler in handlers:
        assert handler.starts == 3
        assert handler.ends == 3
        assert handler.errors == 3


def test_callback_manager() -> None:
    """Test the CallbackManager."""
    handler1 = FakeCallbackHandler()
    handler2 = FakeCallbackHandler()
    manager = CallbackManager([handler1, handler2])
    _test_callback_manager(manager, handler1, handler2)


def test_shared_callback_manager() -> None:
    """Test the SharedCallbackManager."""
    manager1 = SharedCallbackManager()
    manager2 = SharedCallbackManager()

    assert manager1 is manager2

    handler1 = FakeCallbackHandler()
    handler2 = FakeCallbackHandler()

    manager1.add_handler(handler1)
    manager2.add_handler(handler2)

    _test_callback_manager(manager1, handler1, handler2)

(modified file, path not shown)

@@ -7,8 +7,8 @@ from pydantic import BaseModel
 
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
 from langchain.embeddings.hyde.prompts import PROMPT_MAP
-from langchain.llms.base import BaseLLM, LLMResult
-from langchain.schema import Generation
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
 
 class FakeEmbeddings(Embeddings):