WIP: stdout callback (#479)

First pass at a stdout callback.

For the most part, this went pretty smoothly. Aside from the code here, here are some comments/observations:

1. The shared callback manager should somehow default to `StdOutCallbackHandler`, so I don't have to do:
```
from langchain.callbacks import get_callback_manager
from langchain.callbacks.stdout import StdOutCallbackHandler

get_callback_manager().add_handler(StdOutCallbackHandler())
```
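
For what it's worth, this commit does wire that up: `langchain/__init__.py` now calls `set_default_callback_manager()` at import time. A minimal sketch of what that call amounts to (mirroring the diff below):

```
from langchain.callbacks import get_callback_manager
from langchain.callbacks.stdout import StdOutCallbackHandler


def set_default_callback_manager() -> None:
    """Attach the stdout handler to the shared callback manager."""
    get_callback_manager().add_handler(StdOutCallbackHandler())


# After `import langchain`, this has already run, so no manual wiring is needed.
```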

2. I kept the verbosity flag around. First, it's pretty important for getting the stdout to look good for agents (and other things). Second, I actually added it to the LLM class, since it didn't have one. (Quick usage sketch below.)
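
A quick sketch of how the flag now threads through (assuming `tools` and `llm` are defined elsewhere):

```
from langchain.agents import initialize_agent

# `verbose` now flows from initialize_agent into the AgentExecutor, the
# agent's LLMChain, and (new in this commit) the LLM itself.
agent = initialize_agent(
    tools,
    llm,
    agent="zero-shot-react-description",
    verbose=True,
)
```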

3. The only part that isn't moved over basically perfectly is the end of the agent run. Here's a screenshot of the new stdout tracing:
![Screen Shot 2022-12-29 at 4 03 50
PM](https://user-images.githubusercontent.com/11986836/210011538-6a74551a-2e61-437b-98d3-674212dede56.png)

Notice that it is missing logging of the final thought; e.g., here is what it looked like before:
![Screen Shot 2022-12-29 at 4 13 07
PM](https://user-images.githubusercontent.com/11986836/210011635-de68b3f5-e2b0-4cd3-9f1a-3afe970a8716.png)

The reason it's missing is that this was previously logged as part of agent end (lines 205 and 206).

This is probably only relevant for the stdout logger? Any thoughts on how to get it back in?
Branch: harrison/callback-updates · commit 5d43246694 (parent 36922318d3) · Harrison Chase, committed via GitHub

@@ -4,6 +4,7 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks import set_default_callback_manager
from langchain.chains import (
ConversationChain,
LLMBashChain,
@@ -33,6 +34,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
logger: BaseLogger = StdOutLogger()
verbose: bool = False
llm_cache: Optional[BaseCache] = None
set_default_callback_manager()
__all__ = [
"LLMChain",

@@ -3,11 +3,10 @@ from __future__ import annotations
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain
@@ -17,7 +16,7 @@ from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction, AgentFinish
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
logger = logging.getLogger()
@@ -47,7 +46,7 @@ class Agent(BaseModel):
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentFinish, AgentAction]:
) -> AgentAction:
"""Given input, decided what to do.
Args:
@@ -74,7 +73,7 @@
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
if tool == self.finish_tool_name:
return AgentFinish({"output": tool_input}, full_output)
return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self) -> None:
@@ -138,11 +137,15 @@
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = LLMChain(
llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
llm=llm,
prompt=cls.create_prompt(tools),
callback_manager=callback_manager,
verbose=verbose,
)
return cls(llm_chain=llm_chain)
@@ -217,28 +220,34 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
langchain.logger.log_agent_end(output, color="green")
self._get_callback_manager().on_tool_start(
{"name": "Finish"}, output, color="green"
)
self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
if self.verbose:
langchain.logger.log_agent_action(output, color="green")
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
self._get_callback_manager().on_tool_start(
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
)
if self.verbose:
self._get_callback_manager().on_tool_start(
{"name": str(chain)[:60] + "..."}, output, color="green"
)
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
self._get_callback_manager().on_tool_end(observation)
color = color_mapping[output.tool]
else:
if self.verbose:
self._get_callback_manager().on_tool_start(
{"name": "N/A"}, output, color="green"
)
observation = f"{output.tool} is not a valid tool, try another one."
color = None
if self.verbose:
langchain.logger.log_agent_observation(
self._get_callback_manager().on_tool_end(
observation,
color=color,
observation_prefix=self.agent.observation_prefix,

@@ -21,6 +21,7 @@ def initialize_agent(
llm: BaseLLM,
agent: str = "zero-shot-react-description",
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False,
**kwargs: Any,
) -> AgentExecutor:
"""Load agent given tools and LLM.
@@ -30,6 +31,7 @@
llm: Language model to use as the agent.
agent: The agent to use. Valid options are:
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
verbose: Whether to use the callback manager for this particular agent.
callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None.
**kwargs: Additional key word arguments to pass to the agent.
@@ -44,8 +46,12 @@
)
agent_cls = AGENT_TO_CLASS[agent]
agent_obj = agent_cls.from_llm_and_tools(
llm, tools, callback_manager=callback_manager
llm, tools, callback_manager=callback_manager, verbose=verbose
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
agent=agent_obj,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**kwargs,
)

@@ -1,8 +1,15 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()
def set_default_callback_manager() -> None:
"""Set default callback manager."""
callback = get_callback_manager()
callback.add_handler(StdOutCallbackHandler())

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain.schema import LLMResult
from langchain.schema import AgentAction, LLMResult
class BaseCallbackHandler(ABC):
@@ -11,7 +11,7 @@ class BaseCallbackHandler(ABC):
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
@@ -28,7 +28,7 @@ class BaseCallbackHandler(ABC):
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
@@ -42,12 +42,12 @@
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str) -> None:
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
@abstractmethod
@@ -75,11 +75,11 @@ class CallbackManager(BaseCallbackManager):
self.handlers = handlers
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
handler.on_llm_start(serialized, prompts, **extra)
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(
self,
@@ -95,11 +95,11 @@ class CallbackManager(BaseCallbackManager):
handler.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
handler.on_chain_start(serialized, inputs, **extra)
handler.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
@@ -112,16 +112,16 @@ class CallbackManager(BaseCallbackManager):
handler.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
handler.on_tool_start(serialized, action, tool_input, **extra)
handler.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str) -> None:
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
handler.on_tool_end(output)
handler.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""

@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import LLMResult
from langchain.schema import AgentAction, LLMResult
class Singleton:
@@ -35,11 +35,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
_callback_manager: CallbackManager = CallbackManager([])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start(serialized, prompts, **extra)
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(
self,
@@ -55,11 +55,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
self._callback_manager.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
with self._lock:
self._callback_manager.on_chain_start(serialized, inputs, **extra)
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
@@ -72,18 +72,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
self._callback_manager.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(
serialized, action, tool_input, **extra
)
self._callback_manager.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str) -> None:
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
with self._lock:
self._callback_manager.on_tool_end(output)
self._callback_manager.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""

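For context on why a process-wide default callback manager works at all: `SharedCallbackManager` is a singleton, so every call site sees the same handler list. A tiny sketch:

```
from langchain.callbacks.shared import SharedCallbackManager

a = SharedCallbackManager()
b = SharedCallbackManager()
assert a is b  # one shared instance per process, hence one shared handler list
```
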
@@ -0,0 +1,69 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print("Prompts after formatting:")
for prompt in prompts:
print_text(prompt, color="green", end="\n")
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
print_text(action.log, color=color)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if output != AGENT_FINISH_OBSERVATION:
print_text(f"\n{observation_prefix}")
print_text(output, color=color)
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass

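A small sketch of the tool-end behavior above: the sentinel observation emitted for `AgentFinish` is swallowed rather than printed, which is exactly where the missing final thought from point 3 shows up:

```
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.schema import AGENT_FINISH_OBSERVATION

handler = StdOutCallbackHandler()
# A normal observation is printed with its prefixes...
handler.on_tool_end(
    "42 degrees", observation_prefix="Observation: ", llm_prefix="Thought: "
)
# ...but the finish sentinel prints nothing at all.
handler.on_tool_end(AGENT_FINISH_OBSERVATION)
```
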
@@ -120,16 +120,12 @@ class Chain(BaseModel, ABC):
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
if self.verbose:
print(
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
self._get_callback_manager().on_chain_start(
{"name": self.__class__.__name__}, inputs
)
self._get_callback_manager().on_chain_start(
{"name": self.__class__.__name__}, inputs
)
outputs = self._call(inputs)
self._get_callback_manager().on_chain_end(outputs)
if self.verbose:
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
self._get_callback_manager().on_chain_end(outputs)
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)

@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Sequence, Union
from pydantic import BaseModel, Extra
import langchain
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
@@ -54,6 +53,7 @@ class LLMChain(Chain, BaseModel):
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
self.llm.verbose = self.verbose
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
@@ -61,8 +61,6 @@
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
if self.verbose:
langchain.logger.log_llm_inputs(selected_inputs, prompt)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
@@ -78,8 +76,6 @@
for generation in response.generations:
# Get the text of the top generated string.
response_str = generation[0].text
if self.verbose:
langchain.logger.log_llm_response(response_str)
outputs.append({self.output_key: response_str})
return outputs

@@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
import yaml
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, Field
import langchain
from langchain.callbacks import get_callback_manager
@@ -13,10 +13,16 @@ from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import Generation, LLMResult
def _get_verbosity() -> bool:
return langchain.verbose
class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: Optional[BaseCallbackManager] = None
class Config:
@@ -48,11 +54,13 @@ class BaseLLM(BaseModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
self._get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, prompts
)
if self.verbose:
self._get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, prompts
)
output = self._generate(prompts, stop=stop)
self._get_callback_manager().on_llm_end(output)
if self.verbose:
self._get_callback_manager().on_llm_end(output)
return output
params = self._llm_dict()
params["stop"] = stop

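One nuance of the `default_factory=_get_verbosity` change above: the global `langchain.verbose` flag is read at construction time, so it acts as a process-wide default that an explicit `verbose=` still overrides. A sketch (`MyLLM` is a hypothetical `BaseLLM` subclass, purely for illustration):

```
import langchain

langchain.verbose = True

llm = MyLLM()  # hypothetical BaseLLM subclass
assert llm.verbose  # defaulted from the global flag at construction time

quiet = MyLLM(verbose=False)  # an explicit value still wins
assert not quiet.verbose
```
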
@@ -1,9 +1,13 @@
"""Common schema objects."""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
AGENT_FINISH_OBSERVATION = "__agent_finish__"
class AgentAction(NamedTuple):
@dataclass
class AgentAction:
"""Agent's action to take."""
tool: str
@@ -11,11 +15,11 @@ class AgentAction(NamedTuple):
log: str
class AgentFinish(NamedTuple):
@dataclass
class AgentFinish(AgentAction):
"""Agent's return value."""
return_values: dict
log: str
class Generation(NamedTuple):

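To make the schema change above concrete: both types are now plain dataclasses, and `AgentFinish` subclasses `AgentAction` (which is why `Agent.plan` can now be typed as returning `AgentAction`, while `AgentExecutor` still distinguishes the finish case via `isinstance`). A sketch:

```
from langchain.schema import AgentAction, AgentFinish

action = AgentAction(tool="Search", tool_input="langchain", log="full LLM output")
# Positional order mirrors the constructor call in agent.py:
# tool, tool_input, log, return_values
finish = AgentFinish("Final Answer", "42", "full LLM output", {"output": "42"})
assert isinstance(finish, AgentAction)
```
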
@@ -74,7 +74,7 @@ def test_agent_with_callbacks() -> None:
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
]
@@ -89,7 +89,36 @@ def test_agent_with_callbacks() -> None:
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending
assert handler.starts == 7
assert handler.ends == 7
assert handler.errors == 0
def test_agent_with_callbacks_not_verbose() -> None:
"""Test react chain with callbacks but not verbose."""
handler = FakeCallbackHandler()
manager = CallbackManager([handler])
tool = "Search"
responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
callback_manager=manager,
)
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
assert handler.ends == 6
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

@@ -2,7 +2,7 @@
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from langchain.schema import AgentAction, LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
@@ -15,7 +15,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
self.errors = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.starts += 1
@@ -32,7 +32,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.starts += 1
@@ -46,12 +46,12 @@ class FakeCallbackHandler(BaseCallbackHandler):
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str) -> None:
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.ends += 1

@@ -2,7 +2,7 @@
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import LLMResult
from langchain.schema import AgentAction, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -13,10 +13,10 @@ def _test_callback_manager(
manager.on_llm_start({}, [])
manager.on_llm_end(LLMResult(generations=[]))
manager.on_llm_error(Exception())
manager.on_chain_start({}, {})
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, "", "")
manager.on_tool_start({}, AgentAction("", "", ""))
manager.on_tool_end("")
manager.on_tool_error(Exception())
for handler in handlers:

@@ -140,9 +140,20 @@ def test_run_arg_with_memory() -> None:
def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager([handler]))
chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True)
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_run_with_callback_not_verbose() -> None:
"""Test run method works when callback manager is passed and not verbose."""
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager([handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

@@ -7,9 +7,20 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager([handler]))
llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True)
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_llm_with_callbacks_not_verbose() -> None:
"""Test LLM callbacks but not verbose."""
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager([handler]))
output = llm("foo")
assert output == "foo"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0
