mirror of https://github.com/hwchase17/langchain
Add BaseCallbackHandler and CallbackManager (#478)
Co-authored-by: Ankush Gola <9536492+agola11@users.noreply.github.com>pull/533/head
parent
6d78be0c83
commit
9e04c34e20
@ -0,0 +1,20 @@
|
|||||||
|
"""Callback handlers that allow listening to events in LangChain."""
|
||||||
|
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
|
||||||
|
from langchain.callbacks.shared import SharedCallbackManager
|
||||||
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
def get_callback_manager() -> BaseCallbackManager:
    """Return the process-wide shared callback manager singleton."""
    manager = SharedCallbackManager()
    return manager
|
||||||
|
|
||||||
|
def set_handler(handler: BaseCallbackHandler) -> None:
    """Install *handler* as the only handler on the shared callback manager."""
    get_callback_manager().set_handler(handler)
|
||||||
|
|
||||||
|
def set_default_callback_manager() -> None:
    """Reset the shared callback manager to a single stdout handler."""
    default_handler = StdOutCallbackHandler()
    set_handler(default_handler)
@ -0,0 +1,177 @@
|
|||||||
|
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCallbackHandler(BaseModel, ABC):
    """Base callback handler that can be used to handle callbacks from langchain.

    Subclasses implement the ``on_*`` hooks below; a callback manager consults
    the ``ignore_*`` flags before dispatching the corresponding event group.
    """

    # Per-event-group opt-outs: when True, managers skip this handler for
    # LLM, chain, or agent/tool events respectively.
    ignore_llm: bool = False
    ignore_chain: bool = False
    ignore_agent: bool = False

    @abstractmethod
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""

    @abstractmethod
    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""

    @abstractmethod
    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""

    @abstractmethod
    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""

    @abstractmethod
    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""

    @abstractmethod
    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""

    @abstractmethod
    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""

    @abstractmethod
    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""

    @abstractmethod
    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""

    @abstractmethod
    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""

    @abstractmethod
    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
|
||||||
|
|
||||||
|
class BaseCallbackManager(BaseCallbackHandler, ABC):
    """Base callback manager that can be used to handle callbacks from LangChain.

    A manager is itself a handler; concrete managers fan events out to the
    handlers registered through the methods below.
    """

    @abstractmethod
    def add_handler(self, callback: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""

    @abstractmethod
    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""

    @abstractmethod
    def set_handler(self, handler: BaseCallbackHandler) -> None:
        """Set handler as the only handler on the callback manager."""
|
||||||
|
|
||||||
|
class CallbackManager(BaseCallbackManager):
    """Callback manager that fans each event out to a list of handlers.

    Each ``on_*`` event is forwarded to every registered handler unless that
    handler opts out via the matching ``ignore_*`` flag.
    """

    handlers: List[BaseCallbackHandler]

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        for callback in self.handlers:
            if callback.ignore_llm:
                continue
            callback.on_llm_start(serialized, prompts, **kwargs)

    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""
        for callback in self.handlers:
            if callback.ignore_llm:
                continue
            callback.on_llm_end(response)

    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""
        for callback in self.handlers:
            if callback.ignore_llm:
                continue
            callback.on_llm_error(error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        for callback in self.handlers:
            if callback.ignore_chain:
                continue
            callback.on_chain_start(serialized, inputs, **kwargs)

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""
        for callback in self.handlers:
            if callback.ignore_chain:
                continue
            callback.on_chain_end(outputs)

    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""
        for callback in self.handlers:
            if callback.ignore_chain:
                continue
            callback.on_chain_error(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        for callback in self.handlers:
            if callback.ignore_agent:
                continue
            callback.on_tool_start(serialized, action, **kwargs)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        for callback in self.handlers:
            if callback.ignore_agent:
                continue
            callback.on_tool_end(output, **kwargs)

    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""
        for callback in self.handlers:
            if callback.ignore_agent:
                continue
            callback.on_tool_error(error)

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on additional input from chains and agents."""
        # Text events are dispatched unconditionally: no ignore flag applies.
        for callback in self.handlers:
            callback.on_text(text, **kwargs)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
        for callback in self.handlers:
            if callback.ignore_agent:
                continue
            callback.on_agent_finish(finish, **kwargs)

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)

    def set_handler(self, handler: BaseCallbackHandler) -> None:
        """Set handler as the only handler on the callback manager."""
        self.handlers = [handler]
|
@ -0,0 +1,114 @@
|
|||||||
|
"""A shared CallbackManager."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.base import (
|
||||||
|
BaseCallbackHandler,
|
||||||
|
BaseCallbackManager,
|
||||||
|
CallbackManager,
|
||||||
|
)
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton:
    """A thread-safe singleton class that can be inherited from."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls) -> Any:
        """Create a new shared instance of the class."""
        if cls._instance is None:
            # Double-checked locking: a competing thread may have created
            # the instance while we were waiting for the lock, so re-test
            # under the lock before allocating.
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance
||||||
|
|
||||||
|
|
||||||
|
class SharedCallbackManager(Singleton, BaseCallbackManager):
    """A thread-safe singleton CallbackManager.

    Every event is forwarded to one process-wide ``CallbackManager`` under
    the singleton's lock, so a handler set can be shared across threads.
    """

    # Single wrapped manager shared by all (identical) singleton instances.
    _callback_manager: CallbackManager = CallbackManager(handlers=[])

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with self._lock:
            self._callback_manager.on_llm_start(serialized, prompts, **kwargs)

    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""
        with self._lock:
            self._callback_manager.on_llm_end(response)

    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""
        with self._lock:
            self._callback_manager.on_llm_error(error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        with self._lock:
            self._callback_manager.on_chain_start(serialized, inputs, **kwargs)

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""
        with self._lock:
            self._callback_manager.on_chain_end(outputs)

    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""
        with self._lock:
            self._callback_manager.on_chain_error(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        with self._lock:
            self._callback_manager.on_tool_start(serialized, action, **kwargs)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        with self._lock:
            self._callback_manager.on_tool_end(output, **kwargs)

    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""
        with self._lock:
            self._callback_manager.on_tool_error(error)

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""
        with self._lock:
            self._callback_manager.on_text(text, **kwargs)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
        with self._lock:
            self._callback_manager.on_agent_finish(finish, **kwargs)

    def add_handler(self, callback: BaseCallbackHandler) -> None:
        """Add a callback to the callback manager."""
        with self._lock:
            self._callback_manager.add_handler(callback)

    def remove_handler(self, callback: BaseCallbackHandler) -> None:
        """Remove a callback from the callback manager."""
        with self._lock:
            self._callback_manager.remove_handler(callback)

    def set_handler(self, handler: BaseCallbackHandler) -> None:
        """Set handler as the only handler on the callback manager."""
        with self._lock:
            # Delegate to the wrapped manager's API instead of mutating its
            # ``handlers`` list directly — consistent with add_handler and
            # remove_handler above; behavior is identical.
            self._callback_manager.set_handler(handler)
@ -0,0 +1,82 @@
|
|||||||
|
"""Callback Handler that prints to std out."""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
|
from langchain.input import print_text
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class StdOutCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing."""
        # Fixed docstring: the original claimed "Print out the prompts"
        # but the body is intentionally a no-op.

    def on_llm_end(self, response: LLMResult) -> None:
        """Do nothing."""

    def on_llm_error(self, error: Exception) -> None:
        """Do nothing."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized["name"]
        # \033[1m / \033[0m are ANSI bold on/off escapes framing the banner.
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Print out that we finished a chain."""
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: Exception) -> None:
        """Do nothing."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        action: AgentAction,
        color: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Print out the log in specified color."""
        print_text(action.log, color=color)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        print_text(f"\n{observation_prefix}")
        print_text(output, color=color)
        print_text(f"\n{llm_prefix}")

    def on_tool_error(self, error: Exception) -> None:
        """Do nothing."""

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Print arbitrary text, optionally in color."""
        # Fixed: docstring previously said "Run when agent ends", and the
        # **kwargs annotation was Optional[str] — inconsistent with every
        # sibling hook, which uses Any.
        print_text(text, color=color, end=end)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print_text(finish.log, color=color, end="\n")
@ -0,0 +1,77 @@
|
|||||||
|
"""Callback Handler that logs to streamlit."""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class StreamlitCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to streamlit."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        st.write("Prompts after formatting:")
        for prompt in prompts:
            st.write(prompt)

    def on_llm_end(self, response: LLMResult) -> None:
        """Do nothing."""
        pass

    def on_llm_error(self, error: Exception) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized["name"]
        st.write(f"Entering new {class_name} chain...")

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Print out that we finished a chain."""
        st.write("Finished chain.")

    def on_chain_error(self, error: Exception) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        action: AgentAction,
        **kwargs: Any,
    ) -> None:
        """Print out the log in specified color."""
        # st.write requires two spaces before a newline to render it
        st.markdown(action.log.replace("\n", "  \n"))

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        st.write(f"{observation_prefix}{output}")
        st.write(llm_prefix)

    def on_tool_error(self, error: Exception) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on text."""
        # st.write requires two spaces before a newline to render it
        st.write(text.replace("\n", "  \n"))

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
        # st.write requires two spaces before a newline to render it
        st.write(finish.log.replace("\n", "  \n"))
@ -1,71 +0,0 @@
|
|||||||
"""BETA: everything in here is highly experimental, do not rely on."""
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from langchain.input import print_text
|
|
||||||
from langchain.schema import AgentAction, AgentFinish
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLogger:
    """Base logging interface.

    All hooks are no-ops by default; subclasses override the events they
    care about.
    """

    def log_agent_start(self, text: str, **kwargs: Any) -> None:
        """Log the start of an agent interaction."""
        pass

    def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Log the end of an agent interaction."""
        pass

    def log_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
        """Log agent action decision."""
        pass

    def log_agent_observation(self, observation: str, **kwargs: Any) -> None:
        """Log agent observation."""
        pass

    def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
        """Log LLM inputs."""
        pass

    def log_llm_response(self, output: str, **kwargs: Any) -> None:
        """Log LLM response."""
        pass
|
||||||
|
|
||||||
|
|
||||||
class StdOutLogger(BaseLogger):
    """Interface for printing things to stdout.

    Note: ``log_llm_response`` is inherited from ``BaseLogger`` and stays
    a no-op here.
    """

    def log_agent_start(self, text: str, **kwargs: Any) -> None:
        """Print the text to start the agent."""
        print_text(text)

    def log_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Print the log of the action in a certain color."""
        print_text(action.log, color=color)

    def log_agent_observation(
        self,
        observation: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Print the observation in a special color."""
        print_text(f"\n{observation_prefix}")
        print_text(observation, color=color)
        print_text(f"\n{llm_prefix}")

    def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
        """Print the prompt in green."""
        print("Prompt after formatting:")
        print_text(prompt, color="green", end="\n")

    def log_agent_end(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Log the end of an agent interaction."""
        print_text(finish.log, color=color)
|
@ -0,0 +1 @@
|
|||||||
|
"""Tests for correct functioning of callbacks."""
|
@ -0,0 +1,67 @@
|
|||||||
|
"""A fake callback handler for testing purposes."""
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class FakeCallbackHandler(BaseCallbackHandler):
    """Fake callback handler for testing.

    Counts how many start / end / error / text events were dispatched to it.
    """

    # Event counters incremented by the hooks below.
    starts: int = 0
    ends: int = 0
    errors: int = 0
    text: int = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        self.starts += 1

    def on_llm_end(
        self,
        response: LLMResult,
    ) -> None:
        """Run when LLM ends running."""
        self.ends += 1

    def on_llm_error(self, error: Exception) -> None:
        """Run when LLM errors."""
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.starts += 1

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        """Run when chain ends running."""
        self.ends += 1

    def on_chain_error(self, error: Exception) -> None:
        """Run when chain errors."""
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.starts += 1

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.ends += 1

    def on_tool_error(self, error: Exception) -> None:
        """Run when tool errors."""
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text events."""
        self.text += 1

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.ends += 1
@ -0,0 +1,97 @@
|
|||||||
|
"""Test CallbackManager."""
|
||||||
|
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
|
||||||
|
from langchain.callbacks.shared import SharedCallbackManager
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
def _test_callback_manager(
    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
    """Fire one of every event on *manager* and check each handler's counters."""
    manager.on_llm_start({}, [])
    manager.on_llm_end(LLMResult(generations=[]))
    manager.on_llm_error(Exception())
    manager.on_chain_start({"name": "foo"}, {})
    manager.on_chain_end({})
    manager.on_chain_error(Exception())
    manager.on_tool_start({}, AgentAction("", "", ""))
    manager.on_tool_end("")
    manager.on_tool_error(Exception())
    manager.on_agent_finish(AgentFinish({}, ""))
    for handler in handlers:
        # 3 starts: llm, chain, tool.
        assert handler.starts == 3
        # 4 ends: llm, chain, tool, plus agent_finish.
        assert handler.ends == 4
        # 3 errors: llm, chain, tool.
        assert handler.errors == 3
||||||
|
|
||||||
|
|
||||||
|
def test_callback_manager() -> None:
    """Test the CallbackManager."""
    fakes = [FakeCallbackHandler(), FakeCallbackHandler()]
    manager = CallbackManager(handlers=fakes)
    _test_callback_manager(manager, *fakes)
||||||
|
|
||||||
|
|
||||||
|
def test_ignore_llm() -> None:
    """Test ignore llm param for callback handlers."""
    # handler1 opts out of LLM events; handler2 should still get them all.
    handler1 = FakeCallbackHandler(ignore_llm=True)
    handler2 = FakeCallbackHandler()
    manager = CallbackManager(handlers=[handler1, handler2])
    manager.on_llm_start({}, [])
    manager.on_llm_end(LLMResult(generations=[]))
    manager.on_llm_error(Exception())
    assert handler1.starts == 0
    assert handler1.ends == 0
    assert handler1.errors == 0
    assert handler2.starts == 1
    assert handler2.ends == 1
    assert handler2.errors == 1
||||||
|
|
||||||
|
|
||||||
|
def test_ignore_chain() -> None:
    """Test ignore chain param for callback handlers."""
    # handler1 opts out of chain events; handler2 should still get them all.
    handler1 = FakeCallbackHandler(ignore_chain=True)
    handler2 = FakeCallbackHandler()
    manager = CallbackManager(handlers=[handler1, handler2])
    manager.on_chain_start({"name": "foo"}, {})
    manager.on_chain_end({})
    manager.on_chain_error(Exception())
    assert handler1.starts == 0
    assert handler1.ends == 0
    assert handler1.errors == 0
    assert handler2.starts == 1
    assert handler2.ends == 1
    assert handler2.errors == 1
||||||
|
|
||||||
|
|
||||||
|
def test_ignore_agent() -> None:
    """Test ignore agent param for callback handlers."""
    # handler1 opts out of agent/tool events; handler2 should still get them.
    handler1 = FakeCallbackHandler(ignore_agent=True)
    handler2 = FakeCallbackHandler()
    manager = CallbackManager(handlers=[handler1, handler2])
    manager.on_tool_start({}, AgentAction("", "", ""))
    manager.on_tool_end("")
    manager.on_tool_error(Exception())
    manager.on_agent_finish(AgentFinish({}, ""))
    assert handler1.starts == 0
    assert handler1.ends == 0
    assert handler1.errors == 0
    assert handler2.starts == 1
    # 2 ends: tool_end plus agent_finish.
    assert handler2.ends == 2
    assert handler2.errors == 1
||||||
|
|
||||||
|
|
||||||
|
def test_shared_callback_manager() -> None:
    """Test the SharedCallbackManager."""
    manager1 = SharedCallbackManager()
    manager2 = SharedCallbackManager()

    # The singleton must return the same instance to every caller.
    assert manager1 is manager2

    # Handlers added through either reference are visible to both.
    handler1 = FakeCallbackHandler()
    handler2 = FakeCallbackHandler()
    manager1.add_handler(handler1)
    manager2.add_handler(handler2)
    _test_callback_manager(manager1, handler1, handler2)
@ -0,0 +1,30 @@
|
|||||||
|
"""Test LLM callbacks."""
|
||||||
|
from langchain.callbacks.base import CallbackManager
|
||||||
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_with_callbacks() -> None:
    """Test LLM callbacks."""
    handler = FakeCallbackHandler()
    # verbose=True enables callback dispatch on the LLM.
    llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
    output = llm("foo")
    assert output == "foo"
    # Exactly one start/end pair and no errors for a single call.
    assert handler.starts == 1
    assert handler.ends == 1
    assert handler.errors == 0
||||||
|
|
||||||
|
|
||||||
|
def test_llm_with_callbacks_not_verbose() -> None:
    """Test LLM callbacks but not verbose."""
    import langchain

    # Disable the global verbose flag so callbacks are suppressed.
    langchain.verbose = False

    handler = FakeCallbackHandler()
    llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
    output = llm("foo")
    assert output == "foo"
    # With verbosity off, no callback events should have fired.
    assert handler.starts == 0
    assert handler.ends == 0
    assert handler.errors == 0
Loading…
Reference in New Issue