forked from Archives/langchain
b7708bbec6
Conceptually, there is no reason a tool should know what an "agent action" is; unless there are objections, this can be changed in all callback handlers.
102 lines
3.0 KiB
Python
102 lines
3.0 KiB
Python
"""Callback Handler that prints to std out."""
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from langchain.callbacks.base import BaseCallbackHandler
|
|
from langchain.input import print_text
|
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
|
|
|
|
|
class StdOutCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out.

    Most callbacks are no-ops; only chain start/end, agent actions/finishes,
    tool output and free-form text are printed.
    """

    def __init__(self, color: Optional[str] = None) -> None:
        """Initialize callback handler.

        Args:
            color: Default color handed to ``print_text`` by the callbacks
                that accept a per-call ``color`` when the caller passes a
                falsy one (e.g. ``None``).
        """
        self.color = color

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing (prompts are not printed by this handler)."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain.

        Args:
            serialized: Serialized chain; its ``"name"`` entry is printed.
                A missing ``"name"`` key raises ``KeyError``.
            inputs: Chain inputs (unused).
        """
        class_name = serialized["name"]
        # \033[1m ... \033[0m wraps the message in ANSI bold.
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Print the agent action's log.

        Args:
            action: The action whose ``log`` text is printed.
            color: Color for this call; falls back to ``self.color`` if falsy.
        """
        print_text(action.log, color=color or self.color)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Print the tool output, framed by the optional prefixes.

        Args:
            output: Tool output to print.
            color: Color for ``output``; falls back to ``self.color`` if falsy.
            observation_prefix: Printed on a new line before ``output``;
                skipped when ``None``.
            llm_prefix: Printed on a new line after ``output``; skipped when
                ``None``.
        """
        # Skip absent prefixes rather than printing the literal text "None".
        if observation_prefix is not None:
            print_text(f"\n{observation_prefix}")
        print_text(output, color=color or self.color)
        if llm_prefix is not None:
            print_text(f"\n{llm_prefix}")

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Print arbitrary text.

        Args:
            text: Text to print.
            color: Color for this call; falls back to ``self.color`` if falsy.
            end: Forwarded to ``print_text`` as its ``end`` argument.
        """
        print_text(text, color=color or self.color, end=end)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Print the agent's final log followed by a newline.

        Args:
            finish: The finish event whose ``log`` text is printed.
            color: Color for this call; falls back to ``self.color`` if falsy.
        """
        # Previously `color if self.color else color`, which always evaluated
        # to `color` and never used the handler's default color.
        print_text(finish.log, color=color or self.color, end="\n")
|