core[patch]: callbacks docstrings (#23375)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-06-26 14:11:06 -07:00 committed by GitHub
parent 1141b08eb8
commit 2a5d59b3d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 707 additions and 98 deletions

View File

@ -6,6 +6,7 @@
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
"""
from langchain_core.callbacks.base import (
AsyncCallbackHandler,
BaseCallbackHandler,

View File

@ -1,4 +1,5 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
"""Base callback handler for LangChain."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
@ -54,7 +55,10 @@ class LLMManagerMixin:
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
def on_llm_end(
@ -65,7 +69,14 @@ class LLMManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
"""Run when LLM ends running.
Args:
response (LLMResult): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
def on_llm_error(
self,
@ -76,11 +87,12 @@ class LLMManagerMixin:
**kwargs: Any,
) -> Any:
"""Run when LLM errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
"""
@ -95,7 +107,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
"""Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_chain_error(
self,
@ -105,7 +123,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
"""Run when chain errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_agent_action(
self,
@ -115,7 +139,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
"""Run on agent action.
Args:
action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_agent_finish(
self,
@ -125,7 +155,13 @@ class ChainManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
"""Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
class ToolManagerMixin:
@ -139,7 +175,13 @@ class ToolManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
"""Run when the tool ends running.
Args:
output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
def on_tool_error(
self,
@ -149,7 +191,13 @@ class ToolManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
"""Run when tool errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
class CallbackManagerMixin:
@ -171,6 +219,15 @@ class CallbackManagerMixin:
**ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model,
you should use on_chat_model_start instead.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The prompts.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
def on_chat_model_start(
@ -188,6 +245,15 @@ class CallbackManagerMixin:
**ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead.
Args:
serialized (Dict[str, Any]): The serialized chat model.
messages (List[List[BaseMessage]]): The messages.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
# NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this is exception is thrown
@ -206,7 +272,17 @@ class CallbackManagerMixin:
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
"""Run when the Retriever starts running.
Args:
serialized (Dict[str, Any]): The serialized Retriever.
query (str): The query.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
def on_chain_start(
self,
@ -219,7 +295,17 @@ class CallbackManagerMixin:
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
"""Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
def on_tool_start(
self,
@ -233,7 +319,18 @@ class CallbackManagerMixin:
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
"""Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inputs (Optional[Dict[str, Any]]): The inputs.
kwargs (Any): Additional keyword arguments.
"""
class RunManagerMixin:
@ -247,7 +344,14 @@ class RunManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on arbitrary text."""
"""Run on an arbitrary text.
Args:
text (str): The text.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
def on_retry(
self,
@ -257,7 +361,14 @@ class RunManagerMixin:
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
"""Run on a retry event.
Args:
retry_state (RetryCallState): The retry state.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
class BaseCallbackHandler(
@ -268,11 +379,13 @@ class BaseCallbackHandler(
CallbackManagerMixin,
RunManagerMixin,
):
"""Base callback handler that handles callbacks from LangChain."""
"""Base callback handler for LangChain."""
raise_error: bool = False
"""Whether to raise an error if an exception occurs."""
run_inline: bool = False
"""Whether to run the callback inline."""
@property
def ignore_llm(self) -> bool:
@ -306,7 +419,7 @@ class BaseCallbackHandler(
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that handles callbacks from LangChain."""
"""Async callback handler for LangChain."""
async def on_llm_start(
self,
@ -324,6 +437,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
**ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model,
you should use on_chat_model_start instead.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The prompts.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
async def on_chat_model_start(
@ -341,6 +463,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
**ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead.
Args:
serialized (Dict[str, Any]): The serialized chat model.
messages (List[List[BaseMessage]]): The messages.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
# NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this is exception is thrown
@ -358,7 +489,17 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_llm_end(
self,
@ -369,7 +510,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
"""Run when LLM ends running.
Args:
response (LLMResult): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_llm_error(
self,
@ -384,6 +533,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
Args:
error: The error that occurred.
run_id: The run ID. This is the ID of the current run.
parent_run_id: The parent run ID. This is the ID of the parent run.
tags: The tags.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
@ -400,7 +552,17 @@ class AsyncCallbackHandler(BaseCallbackHandler):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
"""Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
async def on_chain_end(
self,
@ -411,7 +573,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
"""Run when a chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_chain_error(
self,
@ -422,7 +592,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
"""Run when chain errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_tool_start(
self,
@ -436,7 +614,18 @@ class AsyncCallbackHandler(BaseCallbackHandler):
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
"""Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inputs (Optional[Dict[str, Any]]): The inputs.
kwargs (Any): Additional keyword arguments.
"""
async def on_tool_end(
self,
@ -447,7 +636,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
"""Run when the tool ends running.
Args:
output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_tool_error(
self,
@ -458,7 +655,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
"""Run when tool errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_text(
self,
@ -469,7 +674,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on arbitrary text."""
"""Run on an arbitrary text.
Args:
text (str): The text.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_retry(
self,
@ -479,7 +692,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
"""Run on a retry event.
Args:
retry_state (RetryCallState): The retry state.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
"""
async def on_agent_action(
self,
@ -490,7 +710,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent action."""
"""Run on agent action.
Args:
action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_agent_finish(
self,
@ -501,7 +729,15 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent end."""
"""Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
async def on_retriever_start(
self,
@ -514,7 +750,17 @@ class AsyncCallbackHandler(BaseCallbackHandler):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
"""Run on the retriever start.
Args:
serialized (Dict[str, Any]): The serialized retriever.
query (str): The query.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
async def on_retriever_end(
self,
@ -525,7 +771,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
"""Run on the retriever end.
Args:
documents (Sequence[Document]): The documents retrieved.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments."""
async def on_retriever_error(
self,
@ -536,14 +789,22 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
"""Run on retriever error.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.
"""
T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that handles callbacks from LangChain."""
"""Base callback manager for LangChain."""
def __init__(
self,
@ -556,7 +817,18 @@ class BaseCallbackManager(CallbackManagerMixin):
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize callback manager."""
"""Initialize callback manager.
Args:
handlers (List[BaseCallbackHandler]): The handlers.
inheritable_handlers (Optional[List[BaseCallbackHandler]]):
The inheritable handlers. Default is None.
parent_run_id (Optional[UUID]): The parent run ID. Default is None.
tags (Optional[List[str]]): The tags. Default is None.
inheritable_tags (Optional[List[str]]): The inheritable tags.
Default is None.
metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
"""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or []
@ -585,31 +857,56 @@ class BaseCallbackManager(CallbackManagerMixin):
return False
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
"""Add a handler to the callback manager.
Args:
handler (BaseCallbackHandler): The handler to add.
inherit (bool): Whether to inherit the handler. Default is True.
"""
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
self.inheritable_handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
"""Remove a handler from the callback manager.
Args:
handler (BaseCallbackHandler): The handler to remove.
"""
self.handlers.remove(handler)
self.inheritable_handlers.remove(handler)
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager."""
"""Set handlers as the only handlers on the callback manager.
Args:
handlers (List[BaseCallbackHandler]): The handlers to set.
inherit (bool): Whether to inherit the handlers. Default is True.
"""
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
"""Set handler as the only handler on the callback manager.
Args:
handler (BaseCallbackHandler): The handler to set.
inherit (bool): Whether to inherit the handler. Default is True.
"""
self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
"""Add tags to the callback manager.
Args:
tags (List[str]): The tags to add.
inherit (bool): Whether to inherit the tags. Default is True.
"""
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
@ -618,16 +915,32 @@ class BaseCallbackManager(CallbackManagerMixin):
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
"""Remove tags from the callback manager.
Args:
tags (List[str]): The tags to remove.
"""
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
"""Add metadata to the callback manager.
Args:
metadata (Dict[str, Any]): The metadata to add.
inherit (bool): Whether to inherit the metadata. Default is True.
"""
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None:
"""Remove metadata from the callback manager.
Args:
keys (List[str]): The keys to remove.
"""
for key in keys:
self.metadata.pop(key)
self.inheritable_metadata.pop(key)

View File

@ -10,12 +10,23 @@ from langchain_core.utils.input import print_text
class FileCallbackHandler(BaseCallbackHandler):
"""Callback Handler that writes to a file."""
"""Callback Handler that writes to a file.
Parameters:
file: The file to write to.
color: The color to use for the text.
"""
def __init__(
self, filename: str, mode: str = "a", color: Optional[str] = None
) -> None:
"""Initialize callback handler."""
"""Initialize callback handler.
Args:
filename: The filename to write to.
mode: The mode to open the file in. Defaults to "a".
color: The color to use for the text. Defaults to None.
"""
self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
self.color = color
@ -26,7 +37,13 @@ class FileCallbackHandler(BaseCallbackHandler):
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
"""Print out that we are entering a chain.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print_text(
f"\n\n\033[1m> Entering new {class_name} chain...\033[0m",
@ -35,13 +52,25 @@ class FileCallbackHandler(BaseCallbackHandler):
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
"""Print out that we finished a chain.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
"""Run on agent action.
Args:
action (AgentAction): The agent action.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(action.log, color=color or self.color, file=self.file)
def on_tool_end(
@ -52,7 +81,18 @@ class FileCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
"""If not the final action, print out observation.
Args:
output (str): The output to print.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
observation_prefix (Optional[str], optional): The observation prefix.
Defaults to None.
llm_prefix (Optional[str], optional): The LLM prefix.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if observation_prefix is not None:
print_text(f"\n{observation_prefix}", file=self.file)
print_text(output, color=color or self.color, file=self.file)
@ -62,11 +102,26 @@ class FileCallbackHandler(BaseCallbackHandler):
def on_text(
self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any
) -> None:
"""Run when agent ends."""
"""Run when the agent ends.
Args:
text (str): The text to print.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
end (str, optional): The end character. Defaults to "".
**kwargs (Any): Additional keyword arguments.
"""
print_text(text, color=color or self.color, end=end, file=self.file)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
"""Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
color (Optional[str], optional): The color to use for the text.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(finish.log, color=color or self.color, end="\n", file=self.file)

View File

@ -77,7 +77,9 @@ def trace_as_chain_group(
Args:
group_name (str): The name of the chain group.
callback_manager (CallbackManager, optional): The callback manager to use.
Defaults to None.
inputs (Dict[str, Any], optional): The inputs to the chain group.
Defaults to None.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
@ -155,7 +157,9 @@ async def atrace_as_chain_group(
Args:
group_name (str): The name of the chain group.
callback_manager (AsyncCallbackManager, optional): The async callback manager to use,
which manages tracing and other callback behavior.
which manages tracing and other callback behavior. Defaults to None.
inputs (Dict[str, Any], optional): The inputs to the chain group.
Defaults to None.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
@ -218,7 +222,13 @@ Func = TypeVar("Func", bound=Callable)
def shielded(func: Func) -> Func:
"""
Makes so an awaitable method is always shielded from cancellation
Makes so an awaitable method is always shielded from cancellation.
Args:
func (Callable): The function to shield.
Returns:
Callable: The shielded function
"""
@functools.wraps(func)
@ -237,14 +247,14 @@ def handle_event(
) -> None:
"""Generic event handler for CallbackManager.
Note: This function is used by langserve to handle events.
Note: This function is used by LangServe to handle events.
Args:
handlers: The list of handlers that will handle the event
event_name: The name of the event (e.g., "on_llm_start")
handlers: The list of handlers that will handle the event.
event_name: The name of the event (e.g., "on_llm_start").
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event
*args: The arguments to pass to the event handler
that if True will cause the handler to be skipped for the given event.
*args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler
"""
coros: List[Coroutine[Any, Any, Any]] = []
@ -394,17 +404,17 @@ async def ahandle_event(
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for AsyncCallbackManager.
"""Async generic event handler for AsyncCallbackManager.
Note: This function is used by langserve to handle events.
Note: This function is used by LangServe to handle events.
Args:
handlers: The list of handlers that will handle the event
event_name: The name of the event (e.g., "on_llm_start")
handlers: The list of handlers that will handle the event.
event_name: The name of the event (e.g., "on_llm_start").
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
that if True will cause the handler to be skipped for the given event.
*args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler.
"""
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
@ -452,10 +462,13 @@ class BaseRunManager(RunManagerMixin):
The list of inheritable handlers.
parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None.
tags (Optional[List[str]]): The list of tags.
tags (Optional[List[str]]): The list of tags. Defaults to None.
inheritable_tags (Optional[List[str]]): The list of inheritable tags.
Defaults to None.
metadata (Optional[Dict[str, Any]]): The metadata.
Defaults to None.
inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata.
Defaults to None.
"""
self.run_id = run_id
self.handlers = handlers
@ -492,10 +505,11 @@ class RunManager(BaseRunManager):
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received.
"""Run when a text is received.
Args:
text (str): The received text.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The result of the callback.
@ -516,6 +530,12 @@ class RunManager(BaseRunManager):
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
"""Run when a retry is received.
Args:
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
"on_retry",
@ -566,10 +586,11 @@ class AsyncRunManager(BaseRunManager, ABC):
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received.
"""Run when a text is received.
Args:
text (str): The received text.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The result of the callback.
@ -590,6 +611,12 @@ class AsyncRunManager(BaseRunManager, ABC):
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
"""Async run when a retry is received.
Args:
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
"on_retry",
@ -638,6 +665,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
token (str): The new token.
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
@ -656,6 +686,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
@ -725,6 +756,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
token (str): The new token.
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
@ -744,6 +778,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
@ -793,6 +828,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
@ -814,6 +850,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
@ -831,6 +868,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
action (AgentAction): The agent action.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The result of the callback.
@ -851,6 +889,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
finish (AgentFinish): The agent finish.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The result of the callback.
@ -891,10 +930,11 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
"""Run when chain ends running.
"""Run when a chain ends running.
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
@ -917,6 +957,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
@ -935,6 +976,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
action (AgentAction): The agent action.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The result of the callback.
@ -956,6 +998,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
finish (AgentFinish): The agent finish.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The result of the callback.
@ -980,10 +1023,11 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
output: Any,
**kwargs: Any,
) -> None:
"""Run when tool ends running.
"""Run when the tool ends running.
Args:
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
@ -1005,6 +1049,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
@ -1040,10 +1085,11 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
@shielded
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running.
"""Async run when the tool ends running.
Args:
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
@ -1066,6 +1112,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
@ -1087,7 +1134,12 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
documents: Sequence[Document],
**kwargs: Any,
) -> None:
"""Run when retriever ends running."""
"""Run when retriever ends running.
Args:
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
"on_retriever_end",
@ -1104,7 +1156,12 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
"""Run when retriever errors.
Args:
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
handle_event(
self.handlers,
"on_retriever_error",
@ -1144,7 +1201,12 @@ class AsyncCallbackManagerForRetrieverRun(
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
"""Run when retriever ends running."""
"""Run when the retriever ends running.
Args:
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
"on_retriever_end",
@ -1162,7 +1224,12 @@ class AsyncCallbackManagerForRetrieverRun(
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
"""Run when retriever errors.
Args:
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
await ahandle_event(
self.handlers,
"on_retriever_error",
@ -1176,7 +1243,7 @@ class AsyncCallbackManagerForRetrieverRun(
class CallbackManager(BaseCallbackManager):
"""Callback manager that handles callbacks from LangChain."""
"""Callback manager for LangChain."""
def on_llm_start(
self,
@ -1191,6 +1258,7 @@ class CallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The list of prompts.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
List[CallbackManagerForLLMRun]: A callback manager for each
@ -1241,6 +1309,7 @@ class CallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
List[CallbackManagerForLLMRun]: A callback manager for each
@ -1295,6 +1364,7 @@ class CallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized chain.
inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
CallbackManagerForChainRun: The callback manager for the chain run.
@ -1347,6 +1417,7 @@ class CallbackManager(BaseCallbackManager):
input is needed.
If provided, the inputs are expected to be formatted as a dict.
The keys will correspond to the named-arguments in the tool.
**kwargs (Any): Additional keyword arguments.
Returns:
CallbackManagerForToolRun: The callback manager for the tool run.
@ -1387,7 +1458,15 @@ class CallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
"""Run when the retriever starts running.
Args:
serialized (Dict[str, Any]): The serialized retriever.
query (str): The query.
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if run_id is None:
run_id = uuid.uuid4()
@ -1470,6 +1549,16 @@ class CallbackManagerForChainGroup(CallbackManager):
parent_run_manager: CallbackManagerForChainRun,
**kwargs: Any,
) -> None:
"""Initialize the callback manager.
Args:
handlers (List[BaseCallbackHandler]): The list of handlers.
inheritable_handlers (Optional[List[BaseCallbackHandler]]): The list of
inheritable handlers. Defaults to None.
parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None.
parent_run_manager (CallbackManagerForChainRun): The parent run manager.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__(
handlers,
inheritable_handlers,
@ -1480,6 +1569,7 @@ class CallbackManagerForChainGroup(CallbackManager):
self.ended = False
def copy(self) -> CallbackManagerForChainGroup:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
@ -1496,6 +1586,7 @@ class CallbackManagerForChainGroup(CallbackManager):
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
self.ended = True
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
@ -1509,6 +1600,7 @@ class CallbackManagerForChainGroup(CallbackManager):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
self.ended = True
return self.parent_run_manager.on_chain_error(error, **kwargs)
@ -1535,6 +1627,7 @@ class AsyncCallbackManager(BaseCallbackManager):
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The list of prompts.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
List[AsyncCallbackManagerForLLMRun]: The list of async
@ -1591,12 +1684,13 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
"""Async run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
List[AsyncCallbackManagerForLLMRun]: The list of
@ -1651,12 +1745,13 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForChainRun:
"""Run when chain starts running.
"""Async run when chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Union[Dict[str, Any], Any]): The inputs to the chain.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
AsyncCallbackManagerForChainRun: The async callback manager
@ -1697,7 +1792,7 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForToolRun:
"""Run when tool starts running.
"""Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
@ -1705,6 +1800,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
AsyncCallbackManagerForToolRun: The async callback manager
@ -1745,7 +1841,19 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
"""Run when the retriever starts running.
Args:
serialized (Dict[str, Any]): The serialized retriever.
query (str): The query.
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
AsyncCallbackManagerForRetrieverRun: The async callback manager
for the retriever run.
"""
if run_id is None:
run_id = uuid.uuid4()
@ -1828,6 +1936,17 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
parent_run_manager: AsyncCallbackManagerForChainRun,
**kwargs: Any,
) -> None:
"""Initialize the async callback manager.
Args:
handlers (List[BaseCallbackHandler]): The list of handlers.
inheritable_handlers (Optional[List[BaseCallbackHandler]]): The list of
inheritable handlers. Defaults to None.
parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None.
parent_run_manager (AsyncCallbackManagerForChainRun):
The parent run manager.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__(
handlers,
inheritable_handlers,
@ -1838,6 +1957,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
self.ended = False
def copy(self) -> AsyncCallbackManagerForChainGroup:
"""Copy the async callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
@ -1856,6 +1976,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
self.ended = True
await self.parent_run_manager.on_chain_end(outputs, **kwargs)
@ -1869,6 +1990,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
self.ended = True
await self.parent_run_manager.on_chain_error(error, **kwargs)

View File

@ -15,24 +15,45 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
"""Initialize callback handler.
Args:
color: The color to use for the text. Defaults to None.
"""
self.color = color
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
"""Print out that we are entering a chain.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
"""Print out that we finished a chain.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
print("\n\033[1m> Finished chain.\033[0m") # noqa: T201
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
"""Run on agent action.
Args:
action (AgentAction): The agent action.
color (Optional[str]): The color to use for the text. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(action.log, color=color or self.color)
def on_tool_end(
@ -43,7 +64,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
"""If not the final action, print out observation.
Args:
output (Any): The output to print.
color (Optional[str]): The color to use for the text. Defaults to None.
observation_prefix (Optional[str]): The observation prefix.
Defaults to None.
llm_prefix (Optional[str]): The LLM prefix. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
output = str(output)
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
@ -58,11 +88,24 @@ class StdOutCallbackHandler(BaseCallbackHandler):
end: str = "",
**kwargs: Any,
) -> None:
"""Run when agent ends."""
"""Run when the agent ends.
Args:
text (str): The text to print.
color (Optional[str]): The color to use for the text. Defaults to None.
end (str): The end character to use. Defaults to "".
**kwargs (Any): Additional keyword arguments.
"""
print_text(text, color=color or self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
"""Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
color (Optional[str]): The color to use for the text. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
print_text(finish.log, color=color or self.color, end="\n")

View File

@ -1,4 +1,5 @@
"""Callback Handler streams to stdout on new llm token."""
from __future__ import annotations
import sys
@ -18,7 +19,13 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
"""Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
prompts (List[str]): The prompts to run.
**kwargs (Any): Additional keyword arguments.
"""
def on_chat_model_start(
self,
@ -26,47 +33,115 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
"""Run when LLM starts running.
Args:
serialized (Dict[str, Any]): The serialized LLM.
messages (List[List[BaseMessage]]): The messages to run.
**kwargs (Any): Additional keyword arguments.
"""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
**kwargs (Any): Additional keyword arguments.
"""
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
"""Run when LLM ends running.
Args:
response (LLMResult): The response from the LLM.
**kwargs (Any): Additional keyword arguments.
"""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
"""Run when LLM errors.
Args:
error (BaseException): The error that occurred.
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
"""Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
"""Run when a chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
"""Run when chain errors.
Args:
error (BaseException): The error that occurred.
**kwargs (Any): Additional keyword arguments.
"""
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
"""Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
**kwargs (Any): Additional keyword arguments.
"""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
"""Run on agent action.
Args:
action (AgentAction): The agent action.
**kwargs (Any): Additional keyword arguments.
"""
pass
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
"""Run when tool ends running.
Args:
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
"""Run when tool errors.
Args:
error (BaseException): The error that occurred.
**kwargs (Any): Additional keyword arguments.
"""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
"""Run on an arbitrary text.
Args:
text (str): The text to print.
**kwargs (Any): Additional keyword arguments.
"""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
"""Run on the agent end.
Args:
finish (AgentFinish): The agent finish.
**kwargs (Any): Additional keyword arguments.
"""