forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
409 lines
10 KiB
Python
409 lines
10 KiB
Python
"""Base callback handler that can be used to handle callbacks in langchain."""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
from uuid import UUID
|
|
|
|
from langchain.schema import (
|
|
AgentAction,
|
|
AgentFinish,
|
|
BaseMessage,
|
|
LLMResult,
|
|
)
|
|
|
|
|
|
class LLMManagerMixin:
    """Mixin holding the hooks emitted while an LLM run is in progress.

    Each hook takes a required keyword-only ``run_id``, an optional
    ``parent_run_id`` and arbitrary extra ``**kwargs``. The default
    implementations do nothing and return ``None``; subclasses override the
    hooks they need.
    """

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a newly produced LLM token (emitted only when streaming).

        Args:
            token: Text of the token that was just generated.
            run_id: Identifier of the LLM run producing the token.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the completion of an LLM run.

        Args:
            response: The result object produced by the LLM.
            run_id: Identifier of the finished LLM run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a failure raised during an LLM run.

        Args:
            error: The exception (or KeyboardInterrupt) reported for the run.
            run_id: Identifier of the failing LLM run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """
|
|
|
|
|
|
class ChainManagerMixin:
    """Mixin holding the hooks emitted by chains and agents.

    Each hook takes a required keyword-only ``run_id``, an optional
    ``parent_run_id`` and arbitrary extra ``**kwargs``. The default
    implementations do nothing and return ``None``.
    """

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the completion of a chain run.

        Args:
            outputs: Mapping of output names to values produced by the chain.
            run_id: Identifier of the finished chain run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a failure raised during a chain run.

        Args:
            error: The exception (or KeyboardInterrupt) reported for the run.
            run_id: Identifier of the failing chain run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an action chosen by an agent.

        Args:
            action: The agent action being taken.
            run_id: Identifier of the run emitting the action.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an agent reaching its final answer.

        Args:
            finish: The agent's finishing result.
            run_id: Identifier of the run emitting the finish event.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """
|
|
|
|
|
|
class ToolManagerMixin:
    """Mixin holding the hooks emitted when a tool finishes or fails.

    Each hook takes a required keyword-only ``run_id``, an optional
    ``parent_run_id`` and arbitrary extra ``**kwargs``. The default
    implementations do nothing and return ``None``.
    """

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the completion of a tool run.

        Args:
            output: The string returned by the tool.
            run_id: Identifier of the finished tool run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a failure raised during a tool run.

        Args:
            error: The exception (or KeyboardInterrupt) reported for the run.
            run_id: Identifier of the failing tool run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """
|
|
|
|
|
|
class CallbackManagerMixin:
    """Mixin holding the hooks that mark the start of a run.

    Each hook takes the ``serialized`` description of the component being
    started, its inputs, a required keyword-only ``run_id``, an optional
    ``parent_run_id`` and arbitrary extra ``**kwargs``. All defaults are
    no-ops returning ``None``, except ``on_chat_model_start``, which raises
    ``NotImplementedError`` unless a subclass overrides it.
    """

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of an LLM run.

        Args:
            serialized: Serialized description of the LLM.
            prompts: Prompt strings sent to the LLM.
            run_id: Identifier of the starting run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a chat-model run.

        Args:
            serialized: Serialized description of the chat model.
            messages: One list of messages per prompt.
            run_id: Identifier of the starting run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this hook.
        """
        handler_name = self.__class__.__name__
        raise NotImplementedError(
            f"{handler_name} does not implement `on_chat_model_start`"
        )

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a chain run.

        Args:
            serialized: Serialized description of the chain.
            inputs: Mapping of input names to values passed to the chain.
            run_id: Identifier of the starting run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a tool run.

        Args:
            serialized: Serialized description of the tool.
            input_str: The string input handed to the tool.
            run_id: Identifier of the starting run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """
|
|
|
|
|
|
class RunManagerMixin:
    """Mixin holding the free-form text hook available during a run."""

    def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an arbitrary piece of text emitted during a run.

        The default does nothing and returns ``None``.

        Args:
            text: The text being reported.
            run_id: Identifier of the run emitting the text.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional event data; unused by this default.
        """
|
|
|
|
|
|
class BaseCallbackHandler(
    LLMManagerMixin,
    ChainManagerMixin,
    ToolManagerMixin,
    CallbackManagerMixin,
    RunManagerMixin,
):
    """Base callback handler that can be used to handle callbacks from langchain.

    Combines every hook mixin into a single class. It also exposes
    ``ignore_*`` properties, all ``False`` by default, that subclasses can
    override to opt out of whole categories of callbacks. Presumably the
    callback manager checks these before dispatching; that logic is not
    part of this class.
    """

    # Class-level default flag. NOTE(review): presumably read by the dispatcher
    # to decide whether handler exceptions propagate. Confirm against the
    # callback manager.
    raise_error: bool = False

    @property
    def ignore_llm(self) -> bool:
        """Whether LLM callbacks should be skipped for this handler (default: no)."""
        return False

    @property
    def ignore_chain(self) -> bool:
        """Whether chain callbacks should be skipped for this handler (default: no)."""
        return False

    @property
    def ignore_agent(self) -> bool:
        """Whether agent callbacks should be skipped for this handler (default: no)."""
        return False

    @property
    def ignore_chat_model(self) -> bool:
        """Whether chat-model callbacks should be skipped for this handler (default: no)."""
        return False
|
|
|
|
|
|
class AsyncCallbackHandler(BaseCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain.

    Redefines every hook from ``BaseCallbackHandler`` as a coroutine function
    with the same parameters.

    - Every hook takes a required keyword-only ``run_id``, an optional
      ``parent_run_id`` and arbitrary ``**kwargs``.
    - Awaiting a default hook does nothing and yields ``None``.
    - The one exception is ``on_chat_model_start``, which raises
      ``NotImplementedError`` unless overridden.
    """

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle the start of an LLM run given its serialized form and prompts."""

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a chat-model run.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this hook.
        """
        handler_name = self.__class__.__name__
        raise NotImplementedError(
            f"{handler_name} does not implement `on_chat_model_start`"
        )

    async def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle a newly produced LLM token (emitted only when streaming)."""

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle the completion of an LLM run."""

    async def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle a failure raised during an LLM run."""

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle the start of a chain run given its serialized form and inputs."""

    async def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle the completion of a chain run."""

    async def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle a failure raised during a chain run."""

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle the start of a tool run given its serialized form and input."""

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle the completion of a tool run and its string output."""

    async def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle a failure raised during a tool run."""

    async def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle an arbitrary piece of text emitted during a run."""

    async def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle an action chosen by an agent."""

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Handle an agent reaching its final answer."""
|
|
|
|
|
|
class BaseCallbackManager(CallbackManagerMixin):
    """Base callback manager that can be used to handle callbacks from LangChain.

    The manager keeps two lists of handlers:

    * ``handlers`` -- every handler registered on this manager.
    * ``inheritable_handlers`` -- handlers flagged as inheritable, either
      passed to the constructor or added with ``inherit=True``.
      NOTE(review): presumably handed down to child managers by subclasses;
      that propagation is not shown in this class.

    Both lists are owned by the manager. They are copied on construction, so
    registering or removing handlers never mutates a list supplied by the
    caller.
    """

    def __init__(
        self,
        handlers: List[BaseCallbackHandler],
        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
    ) -> None:
        """Initialize callback manager.

        Args:
            handlers: Handlers to register. The list is copied. ``None`` is
                tolerated and treated as an empty list.
            inheritable_handlers: Handlers to mark as inheritable. The list is
                copied. ``None`` means no inheritable handlers.
            parent_run_id: Identifier of the parent run, if any.
        """
        # Copy instead of aliasing the caller's lists. Otherwise add_handler /
        # remove_handler on this manager would silently mutate them and leak
        # handlers into every other manager built from the same list.
        self.handlers: List[BaseCallbackHandler] = list(handlers or [])
        self.inheritable_handlers: List[BaseCallbackHandler] = list(
            inheritable_handlers or []
        )
        self.parent_run_id: Optional[UUID] = parent_run_id

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async (always ``False`` here)."""
        return False

    def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Add a handler to the callback manager.

        Args:
            handler: The handler to register.
            inherit: If true, also record the handler as inheritable.
        """
        self.handlers.append(handler)
        if inherit:
            self.inheritable_handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        Args:
            handler: The handler to unregister.

        Raises:
            ValueError: If ``handler`` is not registered in ``handlers``.
        """
        self.handlers.remove(handler)
        # A handler added with ``inherit=False`` never reaches the inheritable
        # list, so an unconditional ``remove`` here would raise ValueError.
        if handler in self.inheritable_handlers:
            self.inheritable_handlers.remove(handler)

    def set_handlers(
        self, handlers: List[BaseCallbackHandler], inherit: bool = True
    ) -> None:
        """Set handlers as the only handlers on the callback manager.

        Args:
            handlers: Handlers that replace all currently registered ones.
            inherit: Whether each new handler is also recorded as inheritable.
        """
        # Rebind to fresh lists rather than clearing in place, so that passing
        # this manager's own ``handlers`` list as the argument still works.
        self.handlers = []
        self.inheritable_handlers = []
        for handler in handlers:
            self.add_handler(handler, inherit=inherit)

    def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Set handler as the only handler on the callback manager.

        Args:
            handler: The sole handler to keep registered.
            inherit: Whether the handler is also recorded as inheritable.
        """
        self.set_handlers([handler], inherit=inherit)
|