forked from Archives/langchain
WIP: stdout callback (#479)
first pass at stdout callback for the most part, went pretty smoothly. aside from the code here, here are some comments/observations. 1. should somehow default to stdouthandler so i dont have to do ``` from langchain.callbacks import get_callback_manager from langchain.callbacks.stdout import StdOutCallbackHandler get_callback_manager().add_handler(StdOutCallbackHandler()) ``` 2. I kept around the verbosity flag. 1) this is pretty important for getting the stdout to look good for agents (and other things). 2) I actually added this for LLM class since it didn't have it. 3. The only part that isn't basically perfectly moved over is the end of the agent run. Here's a screenshot of the new stdout tracing ![Screen Shot 2022-12-29 at 4 03 50 PM](https://user-images.githubusercontent.com/11986836/210011538-6a74551a-2e61-437b-98d3-674212dede56.png) Noticing it is missing logging of the final thought, eg before this is what it looked like ![Screen Shot 2022-12-29 at 4 13 07 PM](https://user-images.githubusercontent.com/11986836/210011635-de68b3f5-e2b0-4cd3-9f1a-3afe970a8716.png) The reason its missing is that this was previously logged as part of agent end (lines 205 and 206) this is probably only relevant for the std out logger? any thoughts for how to get it back in?harrison/callback-updates
parent
36922318d3
commit
5d43246694
@ -1,8 +1,15 @@
|
||||
"""Callback handlers that allow listening to events in LangChain."""
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
|
||||
|
||||
def get_callback_manager() -> BaseCallbackManager:
|
||||
"""Return the shared callback manager."""
|
||||
return SharedCallbackManager()
|
||||
|
||||
|
||||
def set_default_callback_manager() -> None:
|
||||
"""Set default callback manager."""
|
||||
callback = get_callback_manager()
|
||||
callback.add_handler(StdOutCallbackHandler())
|
||||
|
@ -0,0 +1,69 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print("Prompts after formatting:")
|
||||
for prompt in prompts:
|
||||
print_text(prompt, color="green", end="\n")
|
||||
|
||||
def on_llm_end(self, response: LLMResult) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
class_name = serialized["name"]
|
||||
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
color: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
print_text(action.log, color=color)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
if output != AGENT_FINISH_OBSERVATION:
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
Loading…
Reference in New Issue