forked from Archives/langchain
WIP: stdout callback (#479)
first pass at stdout callback for the most part, went pretty smoothly. aside from the code here, here are some comments/observations. 1. should somehow default to stdouthandler so i dont have to do ``` from langchain.callbacks import get_callback_manager from langchain.callbacks.stdout import StdOutCallbackHandler get_callback_manager().add_handler(StdOutCallbackHandler()) ``` 2. I kept around the verbosity flag. 1) this is pretty important for getting the stdout to look good for agents (and other things). 2) I actually added this for LLM class since it didn't have it. 3. The only part that isn't basically perfectly moved over is the end of the agent run. Here's a screenshot of the new stdout tracing ![Screen Shot 2022-12-29 at 4 03 50 PM](https://user-images.githubusercontent.com/11986836/210011538-6a74551a-2e61-437b-98d3-674212dede56.png) Noticing it is missing logging of the final thought, eg before this is what it looked like ![Screen Shot 2022-12-29 at 4 13 07 PM](https://user-images.githubusercontent.com/11986836/210011635-de68b3f5-e2b0-4cd3-9f1a-3afe970a8716.png) The reason its missing is that this was previously logged as part of agent end (lines 205 and 206) this is probably only relevant for the std out logger? any thoughts for how to get it back in?
This commit is contained in:
parent
36922318d3
commit
5d43246694
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.cache import BaseCache
|
||||
from langchain.callbacks import set_default_callback_manager
|
||||
from langchain.chains import (
|
||||
ConversationChain,
|
||||
LLMBashChain,
|
||||
@ -33,6 +34,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||
logger: BaseLogger = StdOutLogger()
|
||||
verbose: bool = False
|
||||
llm_cache: Optional[BaseCache] = None
|
||||
set_default_callback_manager()
|
||||
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
|
@ -3,11 +3,10 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
import langchain
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.base import Chain
|
||||
@ -17,7 +16,7 @@ from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@ -47,7 +46,7 @@ class Agent(BaseModel):
|
||||
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentFinish, AgentAction]:
|
||||
) -> AgentAction:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
@ -74,7 +73,7 @@ class Agent(BaseModel):
|
||||
parsed_output = self._extract_tool_and_input(full_output)
|
||||
tool, tool_input = parsed_output
|
||||
if tool == self.finish_tool_name:
|
||||
return AgentFinish({"output": tool_input}, full_output)
|
||||
return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
|
||||
return AgentAction(tool, tool_input, full_output)
|
||||
|
||||
def prepare_for_new_call(self) -> None:
|
||||
@ -138,11 +137,15 @@ class Agent(BaseModel):
|
||||
llm: BaseLLM,
|
||||
tools: List[Tool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
|
||||
llm=llm,
|
||||
prompt=cls.create_prompt(tools),
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
)
|
||||
return cls(llm_chain=llm_chain)
|
||||
|
||||
@ -217,28 +220,34 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.verbose:
|
||||
langchain.logger.log_agent_end(output, color="green")
|
||||
self._get_callback_manager().on_tool_start(
|
||||
{"name": "Finish"}, output, color="green"
|
||||
)
|
||||
self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
if self.verbose:
|
||||
langchain.logger.log_agent_action(output, color="green")
|
||||
|
||||
# And then we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
chain = name_to_tool_map[output.tool]
|
||||
self._get_callback_manager().on_tool_start(
|
||||
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
|
||||
)
|
||||
if self.verbose:
|
||||
self._get_callback_manager().on_tool_start(
|
||||
{"name": str(chain)[:60] + "..."}, output, color="green"
|
||||
)
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = chain(output.tool_input)
|
||||
self._get_callback_manager().on_tool_end(observation)
|
||||
color = color_mapping[output.tool]
|
||||
else:
|
||||
if self.verbose:
|
||||
self._get_callback_manager().on_tool_start(
|
||||
{"name": "N/A"}, output, color="green"
|
||||
)
|
||||
observation = f"{output.tool} is not a valid tool, try another one."
|
||||
color = None
|
||||
if self.verbose:
|
||||
langchain.logger.log_agent_observation(
|
||||
self._get_callback_manager().on_tool_end(
|
||||
observation,
|
||||
color=color,
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
|
@ -21,6 +21,7 @@ def initialize_agent(
|
||||
llm: BaseLLM,
|
||||
agent: str = "zero-shot-react-description",
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Load agent given tools and LLM.
|
||||
@ -30,6 +31,7 @@ def initialize_agent(
|
||||
llm: Language model to use as the agent.
|
||||
agent: The agent to use. Valid options are:
|
||||
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
|
||||
verbose: Whether to use the callback manager for this particular agent.
|
||||
callback_manager: CallbackManager to use. Global callback manager is used if
|
||||
not provided. Defaults to None.
|
||||
**kwargs: Additional key word arguments to pass to the agent.
|
||||
@ -44,8 +46,12 @@ def initialize_agent(
|
||||
)
|
||||
agent_cls = AGENT_TO_CLASS[agent]
|
||||
agent_obj = agent_cls.from_llm_and_tools(
|
||||
llm, tools, callback_manager=callback_manager
|
||||
llm, tools, callback_manager=callback_manager, verbose=verbose
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
|
||||
agent=agent_obj,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -1,8 +1,15 @@
|
||||
"""Callback handlers that allow listening to events in LangChain."""
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
|
||||
|
||||
def get_callback_manager() -> BaseCallbackManager:
|
||||
"""Return the shared callback manager."""
|
||||
return SharedCallbackManager()
|
||||
|
||||
|
||||
def set_default_callback_manager() -> None:
|
||||
"""Set default callback manager."""
|
||||
callback = get_callback_manager()
|
||||
callback.add_handler(StdOutCallbackHandler())
|
||||
|
@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
|
||||
|
||||
class BaseCallbackHandler(ABC):
|
||||
@ -11,7 +11,7 @@ class BaseCallbackHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
@ -28,7 +28,7 @@ class BaseCallbackHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
@ -42,12 +42,12 @@ class BaseCallbackHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
@abstractmethod
|
||||
@ -75,11 +75,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
self.handlers = handlers
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_llm_start(serialized, prompts, **extra)
|
||||
handler.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
@ -95,11 +95,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
handler.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_chain_start(serialized, inputs, **extra)
|
||||
handler.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
@ -112,16 +112,16 @@ class CallbackManager(BaseCallbackManager):
|
||||
handler.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_tool_start(serialized, action, tool_input, **extra)
|
||||
handler.on_tool_start(serialized, action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
for handler in self.handlers:
|
||||
handler.on_tool_end(output)
|
||||
handler.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
@ -8,7 +8,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
|
||||
|
||||
class Singleton:
|
||||
@ -35,11 +35,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
_callback_manager: CallbackManager = CallbackManager([])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_start(serialized, prompts, **extra)
|
||||
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
@ -55,11 +55,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
self._callback_manager.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_start(serialized, inputs, **extra)
|
||||
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
@ -72,18 +72,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
self._callback_manager.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_start(
|
||||
serialized, action, tool_input, **extra
|
||||
)
|
||||
self._callback_manager.on_tool_start(serialized, action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_end(output)
|
||||
self._callback_manager.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
69
langchain/callbacks/stdout.py
Normal file
69
langchain/callbacks/stdout.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print("Prompts after formatting:")
|
||||
for prompt in prompts:
|
||||
print_text(prompt, color="green", end="\n")
|
||||
|
||||
def on_llm_end(self, response: LLMResult) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
class_name = serialized["name"]
|
||||
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
color: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
print_text(action.log, color=color)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
if output != AGENT_FINISH_OBSERVATION:
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
@ -120,16 +120,12 @@ class Chain(BaseModel, ABC):
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
if self.verbose:
|
||||
print(
|
||||
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
|
||||
self._get_callback_manager().on_chain_start(
|
||||
{"name": self.__class__.__name__}, inputs
|
||||
)
|
||||
self._get_callback_manager().on_chain_start(
|
||||
{"name": self.__class__.__name__}, inputs
|
||||
)
|
||||
outputs = self._call(inputs)
|
||||
self._get_callback_manager().on_chain_end(outputs)
|
||||
if self.verbose:
|
||||
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
|
||||
self._get_callback_manager().on_chain_end(outputs)
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
|
@ -3,7 +3,6 @@ from typing import Any, Dict, List, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
import langchain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@ -54,6 +53,7 @@ class LLMChain(Chain, BaseModel):
|
||||
|
||||
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
self.llm.verbose = self.verbose
|
||||
stop = None
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
@ -61,8 +61,6 @@ class LLMChain(Chain, BaseModel):
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format(**selected_inputs)
|
||||
if self.verbose:
|
||||
langchain.logger.log_llm_inputs(selected_inputs, prompt)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
@ -78,8 +76,6 @@ class LLMChain(Chain, BaseModel):
|
||||
for generation in response.generations:
|
||||
# Get the text of the top generated string.
|
||||
response_str = generation[0].text
|
||||
if self.verbose:
|
||||
langchain.logger.log_llm_response(response_str)
|
||||
outputs.append({self.output_key: response_str})
|
||||
return outputs
|
||||
|
||||
|
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
@ -13,10 +13,16 @@ from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
class BaseLLM(BaseModel, ABC):
|
||||
"""LLM wrapper should take in a prompt and return a string."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
|
||||
class Config:
|
||||
@ -48,11 +54,13 @@ class BaseLLM(BaseModel, ABC):
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
self._get_callback_manager().on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts
|
||||
)
|
||||
if self.verbose:
|
||||
self._get_callback_manager().on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts
|
||||
)
|
||||
output = self._generate(prompts, stop=stop)
|
||||
self._get_callback_manager().on_llm_end(output)
|
||||
if self.verbose:
|
||||
self._get_callback_manager().on_llm_end(output)
|
||||
return output
|
||||
params = self._llm_dict()
|
||||
params["stop"] = stop
|
||||
|
@ -1,9 +1,13 @@
|
||||
"""Common schema objects."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
AGENT_FINISH_OBSERVATION = "__agent_finish__"
|
||||
|
||||
class AgentAction(NamedTuple):
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""Agent's action to take."""
|
||||
|
||||
tool: str
|
||||
@ -11,11 +15,11 @@ class AgentAction(NamedTuple):
|
||||
log: str
|
||||
|
||||
|
||||
class AgentFinish(NamedTuple):
|
||||
@dataclass
|
||||
class AgentFinish(AgentAction):
|
||||
"""Agent's return value."""
|
||||
|
||||
return_values: dict
|
||||
log: str
|
||||
|
||||
|
||||
class Generation(NamedTuple):
|
||||
|
@ -74,7 +74,7 @@ def test_agent_with_callbacks() -> None:
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
]
|
||||
@ -89,7 +89,36 @@ def test_agent_with_callbacks() -> None:
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 6
|
||||
assert handler.ends == 6
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending
|
||||
assert handler.starts == 7
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_agent_with_callbacks_not_verbose() -> None:
|
||||
"""Test react chain with callbacks but not verbose."""
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
@ -2,7 +2,7 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
@ -15,7 +15,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
self.errors = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.starts += 1
|
||||
@ -32,7 +32,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.starts += 1
|
||||
@ -46,12 +46,12 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.ends += 1
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@ -13,10 +13,10 @@ def _test_callback_manager(
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({}, {})
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, "", "")
|
||||
manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
for handler in handlers:
|
||||
|
@ -140,9 +140,20 @@ def test_run_arg_with_memory() -> None:
|
||||
def test_run_with_callback() -> None:
|
||||
"""Test run method works when callback manager is passed."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager([handler]))
|
||||
chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True)
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_run_with_callback_not_verbose() -> None:
|
||||
"""Test run method works when callback manager is passed and not verbose."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager([handler]))
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
@ -7,9 +7,20 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
def test_llm_with_callbacks() -> None:
|
||||
"""Test LLM callbacks."""
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager([handler]))
|
||||
llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True)
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_llm_with_callbacks_not_verbose() -> None:
|
||||
"""Test LLM callbacks but not verbose."""
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager([handler]))
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
Loading…
Reference in New Issue
Block a user