mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
612 lines
16 KiB
Python
612 lines
16 KiB
Python
"""Base callback handler that can be used to handle callbacks in langchain."""
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
|
from uuid import UUID
|
|
|
|
from tenacity import RetryCallState
|
|
|
|
if TYPE_CHECKING:
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
from langchain_core.documents import Document
|
|
from langchain_core.messages import BaseMessage
|
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
|
|
|
|
|
class RetrieverManagerMixin:
    """Hooks invoked when a retriever run finishes, successfully or not.

    Both hooks are no-ops that return ``None``; subclasses override whichever
    ones they need.
    """

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a retriever run that raised an exception.

        Args:
            error: The exception raised by the retriever.
            run_id: Identifier of the failing retriever run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a retriever run that completed normally.

        Args:
            documents: The documents produced by the retriever.
            run_id: Identifier of the finished retriever run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """
|
|
|
|
|
|
class LLMManagerMixin:
    """Hooks invoked while an LLM run is streaming, finishing, or failing.

    Every hook is a no-op that returns ``None``; subclasses override the ones
    they care about.
    """

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle one freshly generated token; only fired when streaming.

        Args:
            token: The text of the new token.
            chunk: The generated chunk carrying the token together with any
                extra generation information, when the caller provides one.
            run_id: Identifier of the LLM run producing the token.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an LLM run that completed.

        Args:
            response: The final result of the LLM run.
            run_id: Identifier of the finished LLM run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an LLM run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing LLM run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Additional keyword arguments. Callers may include
                ``response`` (an ``LLMResult``) holding whatever was generated
                before the failure.
        """
|
|
|
|
|
|
class ChainManagerMixin:
    """Hooks invoked when a chain finishes or fails, and on agent steps.

    Every hook is a no-op that returns ``None``; subclasses override the ones
    they care about.
    """

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a chain run that completed.

        Args:
            outputs: Mapping of output names to the values the chain produced.
            run_id: Identifier of the finished chain run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a chain run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing chain run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an action chosen by an agent.

        Args:
            action: The action the agent decided to take.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an agent reaching its final answer.

        Args:
            finish: The agent's terminal result.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """
|
|
|
|
|
|
class ToolManagerMixin:
    """Hooks invoked when a tool run finishes, successfully or not.

    Both hooks are no-ops that return ``None``; subclasses override whichever
    ones they need.
    """

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a tool run that completed.

        Args:
            output: The tool's output.
            run_id: Identifier of the finished tool run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a tool run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing tool run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """
|
|
|
|
|
|
class CallbackManagerMixin:
    """Hooks invoked when an LLM, chat model, retriever, chain or tool starts.

    All hooks return ``None`` without doing anything, except
    ``on_chat_model_start``, which raises ``NotImplementedError`` unless a
    subclass overrides it.
    NOTE(review): presumably the raise lets callers detect a missing
    override and fall back to another hook; that handling lives outside this
    class — confirm.
    """

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of an LLM run.

        Args:
            serialized: Serialized description of the LLM.
            prompts: The prompts sent to the LLM.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a chat-model run.

        Args:
            serialized: Serialized description of the chat model.
            messages: Batches of messages sent to the chat model.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this hook.
        """
        class_name = self.__class__.__name__
        message = f"{class_name} does not implement `on_chat_model_start`"
        raise NotImplementedError(message)

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a retriever run.

        Args:
            serialized: Serialized description of the retriever.
            query: The query given to the retriever.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a chain run.

        Args:
            serialized: Serialized description of the chain.
            inputs: Mapping of input names to the values given to the chain.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle the start of a tool run.

        Args:
            serialized: Serialized description of the tool.
            input_str: The tool's input as a string.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """
|
|
|
|
|
|
class RunManagerMixin:
    """Hooks for events that may occur during any kind of run.

    Both hooks are no-ops that return ``None``; subclasses override whichever
    ones they need.
    """

    def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle an arbitrary piece of text emitted during a run.

        Args:
            text: The emitted text.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    def on_retry(
        self,
        retry_state: RetryCallState,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a retry attempt.

        Args:
            retry_state: ``tenacity`` state object describing the retry.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """
|
|
|
|
|
|
class BaseCallbackHandler(
    LLMManagerMixin,
    ChainManagerMixin,
    ToolManagerMixin,
    RetrieverManagerMixin,
    CallbackManagerMixin,
    RunManagerMixin,
):
    """Base class for LangChain callback handlers.

    Inherits every hook from the mixins above. Each ``ignore_*`` property
    reports ``False`` here; a subclass overrides one to signal that it wants
    to skip that category of callbacks.
    """

    # NOTE(review): both flags are only declared here; presumably the code
    # dispatching callbacks reads them (error propagation / inline execution)
    # — confirm against the callback manager.
    raise_error: bool = False

    run_inline: bool = False

    @property
    def ignore_llm(self) -> bool:
        """Report whether LLM callbacks should be skipped (never, by default)."""
        return False

    @property
    def ignore_retry(self) -> bool:
        """Report whether retry callbacks should be skipped (never, by default)."""
        return False

    @property
    def ignore_chain(self) -> bool:
        """Report whether chain callbacks should be skipped (never, by default)."""
        return False

    @property
    def ignore_agent(self) -> bool:
        """Report whether agent callbacks should be skipped (never, by default)."""
        return False

    @property
    def ignore_retriever(self) -> bool:
        """Report whether retriever callbacks should be skipped (never, by default)."""
        return False

    @property
    def ignore_chat_model(self) -> bool:
        """Report whether chat-model callbacks should be skipped (never, by default)."""
        return False
|
|
|
|
|
|
class AsyncCallbackHandler(BaseCallbackHandler):
    """Coroutine-based variant of ``BaseCallbackHandler``.

    Every hook is redeclared as ``async def``. All of them are no-ops except
    ``on_chat_model_start``, which raises ``NotImplementedError`` unless
    overridden.
    """

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle the start of an LLM run.

        Args:
            serialized: Serialized description of the LLM.
            prompts: The prompts sent to the LLM.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchronously handle the start of a chat-model run.

        Args:
            serialized: Serialized description of the chat model.
            messages: Batches of messages sent to the chat model.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this hook.
        """
        class_name = self.__class__.__name__
        message = f"{class_name} does not implement `on_chat_model_start`"
        raise NotImplementedError(message)

    async def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle one new token; only fired when streaming.

        Args:
            token: The text of the new token.
            chunk: The generated chunk carrying the token, when provided.
            run_id: Identifier of the LLM run producing the token.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle an LLM run that completed.

        Args:
            response: The final result of the LLM run.
            run_id: Identifier of the finished run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle an LLM run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Additional keyword arguments. Callers may include
                ``response`` (an ``LLMResult``) holding whatever was generated
                before the failure.
        """

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle the start of a chain run.

        Args:
            serialized: Serialized description of the chain.
            inputs: Mapping of input names to the values given to the chain.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle a chain run that completed.

        Args:
            outputs: Mapping of output names to the values the chain produced.
            run_id: Identifier of the finished run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle a chain run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle the start of a tool run.

        Args:
            serialized: Serialized description of the tool.
            input_str: The tool's input as a string.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle a tool run that completed.

        Args:
            output: The tool's output.
            run_id: Identifier of the finished run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle a tool run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle arbitrary text emitted during a run.

        Args:
            text: The emitted text.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_retry(
        self,
        retry_state: RetryCallState,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchronously handle a retry attempt.

        Args:
            retry_state: ``tenacity`` state object describing the retry.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle an action chosen by an agent.

        Args:
            action: The action the agent decided to take.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle an agent reaching its final answer.

        Args:
            finish: The agent's terminal result.
            run_id: Identifier of the current run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle the start of a retriever run.

        Args:
            serialized: Serialized description of the retriever.
            query: The query given to the retriever.
            run_id: Identifier of the new run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            metadata: Metadata attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle a retriever run that completed.

        Args:
            documents: The documents produced by the retriever.
            run_id: Identifier of the finished run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """

    async def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronously handle a retriever run that raised an exception.

        Args:
            error: The exception that occurred.
            run_id: Identifier of the failing run.
            parent_run_id: Identifier of the enclosing run, if there is one.
            tags: Tags attached to the run.
            **kwargs: Any additional caller-supplied keyword arguments.
        """
|
|
|
|
|
|
# Type variable bound to BaseCallbackManager; used to annotate
# ``BaseCallbackManager.copy`` so that it returns the same (sub)class it is
# called on.
T = TypeVar("T", bound="BaseCallbackManager")
|
|
|
|
|
|
class BaseCallbackManager(CallbackManagerMixin):
    """Base callback manager that handles callbacks from LangChain.

    A manager keeps two parallel views of its configuration: the local
    ``handlers`` / ``tags`` / ``metadata`` and their ``inheritable_*``
    counterparts. The ``add_*`` helpers write to the inheritable view only
    when ``inherit`` is true. The ``remove_*`` helpers clear an entry from
    whichever views contain it.
    NOTE(review): the inheritable view is presumably what gets propagated to
    child run managers; that logic lives outside this class — confirm.
    """

    def __init__(
        self,
        handlers: List[BaseCallbackHandler],
        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
        *,
        tags: Optional[List[str]] = None,
        inheritable_tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        inheritable_metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize callback manager.

        Args:
            handlers: Handlers for this manager. The list object is stored
                as-is (not copied), so later ``add_handler`` /
                ``remove_handler`` calls mutate the caller's list.
            inheritable_handlers: Handlers marked as inheritable.
            parent_run_id: Identifier of the parent run, if any.
            tags: Tags for this manager.
            inheritable_tags: Tags marked as inheritable.
            metadata: Metadata for this manager.
            inheritable_metadata: Metadata marked as inheritable.

        Optional containers that are ``None`` or empty are replaced with a
        fresh empty list/dict. Non-empty ones are stored by reference.
        """
        self.handlers: List[BaseCallbackHandler] = handlers
        self.inheritable_handlers: List[BaseCallbackHandler] = (
            inheritable_handlers or []
        )
        self.parent_run_id: Optional[UUID] = parent_run_id
        self.tags = tags or []
        self.inheritable_tags = inheritable_tags or []
        self.metadata = metadata or {}
        self.inheritable_metadata = inheritable_metadata or {}

    def copy(self: T) -> T:
        """Copy the callback manager.

        Returns:
            A new manager of the same class with the same configuration.
            Handler/tag lists and metadata dicts are shallow-copied, so
            mutating the copy (``add_handler``, ``add_tags``,
            ``add_metadata``, ...) does not affect this manager. The handler
            objects themselves are shared.
        """
        # Copy the containers: passing the same list/dict objects would make
        # every mutation of the copy leak back into the original manager.
        return self.__class__(
            handlers=self.handlers.copy(),
            inheritable_handlers=self.inheritable_handlers.copy(),
            parent_run_id=self.parent_run_id,
            tags=self.tags.copy(),
            inheritable_tags=self.inheritable_tags.copy(),
            metadata=self.metadata.copy(),
            inheritable_metadata=self.inheritable_metadata.copy(),
        )

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async."""
        return False

    def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Add a handler to the callback manager.

        Args:
            handler: Handler to add. Handlers already present are not
                duplicated.
            inherit: Whether to also register the handler as inheritable.
        """
        if handler not in self.handlers:
            self.handlers.append(handler)
        if inherit and handler not in self.inheritable_handlers:
            self.inheritable_handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        The handler is removed from whichever of ``handlers`` /
        ``inheritable_handlers`` contain it. A handler added with
        ``inherit=False`` is never inheritable, so it must not be an error
        for it to be absent there.

        Args:
            handler: Handler to remove.
        """
        if handler in self.handlers:
            self.handlers.remove(handler)
        if handler in self.inheritable_handlers:
            self.inheritable_handlers.remove(handler)

    def set_handlers(
        self, handlers: List[BaseCallbackHandler], inherit: bool = True
    ) -> None:
        """Set handlers as the only handlers on the callback manager.

        Args:
            handlers: Handlers that replace all current handlers.
            inherit: Whether the new handlers are also inheritable.
        """
        self.handlers = []
        self.inheritable_handlers = []
        for handler in handlers:
            self.add_handler(handler, inherit=inherit)

    def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Set handler as the only handler on the callback manager.

        Args:
            handler: Handler that replaces all current handlers.
            inherit: Whether the handler is also inheritable.
        """
        self.set_handlers([handler], inherit=inherit)

    def add_tags(self, tags: List[str], inherit: bool = True) -> None:
        """Add tags to the callback manager.

        A tag that is already present is removed first, then re-appended, so
        it is not duplicated and its inheritability follows ``inherit``.

        Args:
            tags: Tags to add.
            inherit: Whether to also add the tags as inheritable.
        """
        for tag in tags:
            if tag in self.tags:
                self.remove_tags([tag])
        self.tags.extend(tags)
        if inherit:
            self.inheritable_tags.extend(tags)

    def remove_tags(self, tags: List[str]) -> None:
        """Remove tags from the callback manager.

        Each tag is removed from whichever of ``tags`` / ``inheritable_tags``
        contain it. Absent tags are ignored; tags added with ``inherit=False``
        are not inheritable.

        Args:
            tags: Tags to remove.
        """
        for tag in tags:
            if tag in self.tags:
                self.tags.remove(tag)
            if tag in self.inheritable_tags:
                self.inheritable_tags.remove(tag)

    def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
        """Merge metadata into the callback manager.

        Args:
            metadata: Key/value pairs to add; existing keys are overwritten.
            inherit: Whether to also add the pairs as inheritable metadata.
        """
        self.metadata.update(metadata)
        if inherit:
            self.inheritable_metadata.update(metadata)

    def remove_metadata(self, keys: List[str]) -> None:
        """Remove metadata keys from the callback manager.

        Each key is removed from whichever of ``metadata`` /
        ``inheritable_metadata`` contain it. Absent keys are ignored; keys
        added with ``inherit=False`` are not inheritable.

        Args:
            keys: Metadata keys to remove.
        """
        for key in keys:
            self.metadata.pop(key, None)
            self.inheritable_metadata.pop(key, None)
|
|
|
|
|
|
# Type alias for a callbacks specification: None, a list of handler
# instances, or a BaseCallbackManager.
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|