allow for optional CallbackManager in LLM, Chain, and Agent (#482)

harrison/callback-updates
Ankush Gola (committed via GitHub)
parent 46b31626b5
commit 36922318d3
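The commit threads an optional callback_manager through BaseLLM, Chain, Agent.from_llm_and_tools, AgentExecutor.from_agent_and_tools, and initialize_agent; when none is supplied, each object falls back to the global manager returned by get_callback_manager() via a new _get_callback_manager() helper. A minimal standalone sketch of that fallback pattern, using purely illustrative names (Manager, Component) rather than langchain APIs:

    from typing import Optional


    class Manager:
        """Stand-in for a callback manager; records which manager handled an event."""

        def __init__(self, label: str) -> None:
            self.label = label
            self.events: list = []

        def on_start(self, name: str) -> None:
            self.events.append((self.label, name))


    GLOBAL_MANAGER = Manager("global")  # plays the role of get_callback_manager()


    class Component:
        """Stand-in for an LLM/Chain/Agent that accepts an optional manager."""

        def __init__(self, callback_manager: Optional[Manager] = None) -> None:
            self.callback_manager = callback_manager

        def _get_callback_manager(self) -> Manager:
            # Per-instance manager wins; otherwise fall back to the shared global one.
            if self.callback_manager is not None:
                return self.callback_manager
            return GLOBAL_MANAGER

        def run(self, name: str) -> None:
            self._get_callback_manager().on_start(name)


    local = Manager("local")
    Component().run("chain")                        # recorded on GLOBAL_MANAGER
    Component(callback_manager=local).run("chain")  # recorded on the local manager
    assert GLOBAL_MANAGER.events == [("global", "chain")]
    assert local.events == [("local", "chain")]

The tests added at the end of this diff exercise the same precedence through real CallbackManager([handler]) instances passed to FakeListLLM, initialize_agent, FakeChain, and FakeLLM.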

@@ -9,7 +9,7 @@ from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.tools import Tool
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
@@ -133,10 +133,17 @@ class Agent(BaseModel):
pass
@classmethod
def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent:
def from_llm_and_tools(
cls,
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
llm_chain = LLMChain(
llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
)
return cls(llm_chain=llm_chain)
def return_stopped_response(self) -> dict:
@@ -154,10 +161,16 @@ class AgentExecutor(Chain, BaseModel):
@classmethod
def from_agent_and_tools(
cls, agent: Agent, tools: List[Tool], **kwargs: Any
cls,
agent: Agent,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Create from agent and tools."""
return cls(agent=agent, tools=tools, **kwargs)
return cls(
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
)
@property
def input_keys(self) -> List[str]:
@@ -214,12 +227,12 @@ class AgentExecutor(Chain, BaseModel):
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
get_callback_manager().on_tool_start(
self._get_callback_manager().on_tool_start(
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
)
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
get_callback_manager().on_tool_end(observation)
self._get_callback_manager().on_tool_end(observation)
color = color_mapping[output.tool]
else:
observation = f"{output.tool} is not a valid tool, try another one."

@@ -1,11 +1,12 @@
"""Load agent."""
from typing import Any, List
from typing import Any, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.llms.base import BaseLLM
AGENT_TO_CLASS = {
@@ -19,6 +20,7 @@ def initialize_agent(
tools: List[Tool],
llm: BaseLLM,
agent: str = "zero-shot-react-description",
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Load agent given tools and LLM.
@@ -28,6 +30,8 @@ def initialize_agent(
llm: Language model to use as the agent.
agent: The agent to use. Valid options are:
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None.
**kwargs: Additional keyword arguments to pass to the agent.
Returns:
@@ -39,5 +43,9 @@
f"Valid types are: {AGENT_TO_CLASS.keys()}."
)
agent_cls = AGENT_TO_CLASS[agent]
agent_obj = agent_cls.from_llm_and_tools(llm, tools)
return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs)
agent_obj = agent_cls.from_llm_and_tools(
llm, tools, callback_manager=callback_manager
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
)

@@ -6,6 +6,7 @@ from pydantic import BaseModel, Extra, Field
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class Memory(BaseModel, ABC):
@@ -43,9 +44,21 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement."""
memory: Optional[Memory] = None
callback_manager: Optional[BaseCallbackManager] = None
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def _get_callback_manager(self) -> BaseCallbackManager:
"""Get the callback manager."""
if self.callback_manager is not None:
return self.callback_manager
return get_callback_manager()
@property
@abstractmethod
@@ -110,9 +123,11 @@ class Chain(BaseModel, ABC):
print(
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
)
get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs)
self._get_callback_manager().on_chain_start(
{"name": self.__class__.__name__}, inputs
)
outputs = self._call(inputs)
get_callback_manager().on_chain_end(outputs)
self._get_callback_manager().on_chain_end(outputs)
if self.verbose:
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
self._validate_outputs(outputs)

@@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import Generation, LLMResult
@@ -16,11 +17,13 @@ class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None
callback_manager: Optional[BaseCallbackManager] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@abstractmethod
def _generate(
@@ -28,6 +31,12 @@ class BaseLLM(BaseModel, ABC):
) -> LLMResult:
"""Run the LLM on the given prompts."""
def _get_callback_manager(self) -> BaseCallbackManager:
"""Get the callback manager."""
if self.callback_manager is not None:
return self.callback_manager
return get_callback_manager()
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
@@ -39,11 +48,11 @@ class BaseLLM(BaseModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
get_callback_manager().on_llm_start(
self._get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, prompts
)
output = self._generate(prompts, stop=stop)
get_callback_manager().on_llm_end(output)
self._get_callback_manager().on_llm_end(output)
return output
params = self._llm_dict()
params["stop"] = stop
@@ -58,11 +67,11 @@ class BaseLLM(BaseModel, ABC):
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
get_callback_manager().on_llm_start(
self._get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, missing_prompts
)
new_results = self._generate(missing_prompts, stop=stop)
get_callback_manager().on_llm_end(new_results)
self._get_callback_manager().on_llm_end(new_results)
for i, result in enumerate(new_results.generations):
existing_prompts[i] = result
prompt = prompts[i]

@@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.agents import Tool, initialize_agent
from langchain.agents import AgentExecutor, Tool, initialize_agent
from langchain.callbacks.base import CallbackManager
from langchain.llms.base import LLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeListLLM(LLM, BaseModel):
@@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
return "fake_list"
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
def _get_agent(**kwargs: Any) -> AgentExecutor:
"""Get agent for testing."""
bad_action_name = "BadAction"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
@@ -44,30 +46,50 @@ def test_agent_bad_action() -> None:
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
]
agent = initialize_agent(
tools, fake_llm, agent="zero-shot-react-description", verbose=True
tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
)
return agent
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
agent = _get_agent()
output = agent.run("when was langchain made")
assert output == "curses foiled again"
def test_agent_stopped_early() -> None:
"""Test react chain when bad action given."""
bad_action_name = "BadAction"
agent = _get_agent(max_iterations=0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
def test_agent_with_callbacks() -> None:
"""Test react chain with callbacks."""
handler = FakeCallbackHandler()
manager = CallbackManager([handler])
tool = "Search"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses)
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
verbose=True,
max_iterations=0,
callback_manager=manager,
)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
assert handler.ends == 6
assert handler.errors == 0

@@ -0,0 +1,60 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
def __init__(self) -> None:
"""Initialize the mock callback handler."""
self.starts = 0
self.ends = 0
self.errors = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1

@@ -1,69 +1,9 @@
"""Test CallbackManager."""
from typing import Any, Dict, List
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
def __init__(self) -> None:
"""Initialize the mock callback handler."""
self.starts = 0
self.ends = 0
self.errors = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def _test_callback_manager(

@@ -4,7 +4,9 @@ from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from langchain.callbacks.base import CallbackManager
from langchain.chains.base import Chain, Memory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeMemory(Memory, BaseModel):
@@ -133,3 +135,14 @@ def test_run_arg_with_memory() -> None:
"""Test run method works when arg is passed."""
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
chain.run("bar")
def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager([handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0

@@ -0,0 +1,15 @@
"""Test LLM callbacks."""
from langchain.callbacks.base import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager([handler]))
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0