rfc: callback changes (#1165)

conceptually, no reason a tool should know what an "agent action" is

unless any objections, can change in all callback handlers
This commit is contained in:
Harrison Chase 2023-02-20 22:54:15 -08:00 committed by GitHub
parent fb83cd4ff4
commit b7708bbec6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 131 additions and 66 deletions

View File

@ -407,6 +407,9 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
self.callback_manager.on_agent_action(
output, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool]
@ -415,7 +418,7 @@ class AgentExecutor(Chain, BaseModel):
llm_prefix = "" if return_direct else self.agent.llm_prefix
# We then call the tool on the tool input to get an observation
observation = tool.run(
output,
output.tool_input,
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
@ -423,7 +426,7 @@ class AgentExecutor(Chain, BaseModel):
)
else:
observation = InvalidTool().run(
output,
output.tool_input,
verbose=self.verbose,
color=None,
llm_prefix="",
@ -451,6 +454,9 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
self.callback_manager.on_agent_action(
output, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool]
@ -459,7 +465,7 @@ class AgentExecutor(Chain, BaseModel):
llm_prefix = "" if return_direct else self.agent.llm_prefix
# We then call the tool on the tool input to get an observation
observation = await tool.arun(
output,
output.tool_input,
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
@ -467,7 +473,7 @@ class AgentExecutor(Chain, BaseModel):
)
else:
observation = await InvalidTool().arun(
output,
output.tool_input,
verbose=self.verbose,
color=None,
llm_prefix="",

View File

@ -68,7 +68,7 @@ class BaseCallbackHandler(ABC):
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
"""Run when tool starts running."""
@ -86,6 +86,10 @@ class BaseCallbackHandler(ABC):
def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""
@abstractmethod
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
@abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
@ -203,7 +207,7 @@ class CallbackManager(BaseCallbackManager):
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
@ -211,7 +215,16 @@ class CallbackManager(BaseCallbackManager):
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_start(serialized, action, **kwargs)
handler.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when tool ends running."""
@ -293,7 +306,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""Run when chain errors."""
async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
@ -308,6 +321,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
@ -452,7 +468,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
@ -461,12 +477,12 @@ class AsyncCallbackManager(BaseCallbackManager):
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_start):
await handler.on_tool_start(serialized, action, **kwargs)
await handler.on_tool_start(serialized, input_str, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_tool_start, serialized, action, **kwargs
handler.on_tool_start, serialized, input_str, **kwargs
),
)
@ -514,6 +530,23 @@ class AsyncCallbackManager(BaseCallbackManager):
None, functools.partial(handler.on_text, text, **kwargs)
)
async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent action."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_agent_action):
await handler.on_agent_action(action, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_action, action, **kwargs
),
)
async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:

View File

@ -58,8 +58,7 @@ class OpenAICallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
@ -92,6 +91,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
"""Run when agent ends."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:

View File

@ -78,11 +78,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
self._callback_manager.on_chain_error(error, **kwargs)
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(serialized, action, **kwargs)
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
with self._lock:
self._callback_manager.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""

View File

@ -53,11 +53,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
"""Do nothing."""
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color if color else self.color)
def on_tool_end(

View File

@ -41,10 +41,14 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Run when chain errors."""
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""

View File

@ -52,10 +52,14 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
# st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))

View File

@ -199,7 +199,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._end_trace()
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Start a trace for a tool run."""
if self._session is None:
@ -209,8 +209,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
tool_run = ToolRun(
serialized=serialized,
action=action.tool,
tool_input=action.tool_input,
# TODO: this is duplicate info as above, not needed.
action=str(serialized),
tool_input=input_str,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
@ -250,6 +251,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Handle an agent finish message."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
class Tracer(BaseTracer, ABC):
"""A non-thread safe implementation of the BaseTracer interface."""

View File

@ -7,7 +7,6 @@ from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import AgentAction
class BaseTool(BaseModel):
@ -45,12 +44,11 @@ class BaseTool(BaseModel):
def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
agent_action = AgentAction(tool_input=tool_input, tool=self.name, log="")
return self.run(agent_action)
return self.run(tool_input)
def run(
self,
action: AgentAction,
tool_input: str,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
@ -61,13 +59,13 @@ class BaseTool(BaseModel):
verbose = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
action,
tool_input,
verbose=verbose,
color=start_color,
**kwargs,
)
try:
observation = self._run(action.tool_input)
observation = self._run(tool_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose)
raise e
@ -78,7 +76,7 @@ class BaseTool(BaseModel):
async def arun(
self,
action: AgentAction,
tool_input: str,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
@ -90,7 +88,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
action,
tool_input,
verbose=verbose,
color=start_color,
**kwargs,
@ -98,14 +96,14 @@ class BaseTool(BaseModel):
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
action,
tool_input,
verbose=verbose,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(action.tool_input)
observation = await self._arun(tool_input)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose)

View File

@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
assert handler.tool_starts == 2
assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
assert handler.tool_starts == 2
assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0

View File

@ -2,7 +2,6 @@
import pytest
from langchain.agents.tools import Tool, tool
from langchain.schema import AgentAction
def test_unnamed_decorator() -> None:
@ -101,7 +100,4 @@ async def test_create_async_tool() -> None:
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.coroutine is not None
assert (
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
== "foo"
)
assert await test_tool.arun("foo") == "foo"

View File

@ -94,7 +94,7 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
@ -120,6 +120,11 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
self.agent_ends += 1
self.ends += 1
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Fake async callback handler for testing."""
@ -165,7 +170,7 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
self.errors += 1
async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
@ -190,3 +195,8 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1

View File

@ -9,7 +9,7 @@ from langchain.callbacks.base import (
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
@ -27,7 +27,7 @@ def _test_callback_manager(
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, AgentAction("", "", ""))
manager.on_tool_start({}, "")
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
@ -44,7 +44,7 @@ async def _test_callback_manager_async(
await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({})
await manager.on_chain_error(Exception())
await manager.on_tool_start({}, AgentAction("", "", ""))
await manager.on_tool_start({}, "")
await manager.on_tool_end("")
await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
@ -73,7 +73,7 @@ def _test_callback_manager_pass_in_verbose(
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True)
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
@ -136,7 +136,7 @@ def test_ignore_agent() -> None:
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)

View File

@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.schema import AgentAction, LLMResult
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
serialized={},
tool_input="test",
output="test",
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
child_runs=[
@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_end("test")
@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
serialized={},
tool_input="test",
output="test",
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_end("test")
assert tracer.runs == [compare_run]
@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
serialized={},
tool_input="test",
output=None,
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_error(exception)
assert tracer.runs == [compare_run]
@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
error=repr(exception),
tool_input="test",
output=None,
action="action",
action="{}",
child_runs=[
LLMRun(
id=None,