rfc: callback changes (#1165)

conceptually, no reason a tool should know what an "agent action" is

unless any objections, can change in all callback handlers
This commit is contained in:
Harrison Chase 2023-02-20 22:54:15 -08:00 committed by GitHub
parent fb83cd4ff4
commit b7708bbec6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 131 additions and 66 deletions

View File

@ -407,6 +407,9 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output return output
self.callback_manager.on_agent_action(
output, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool # Otherwise we lookup the tool
if output.tool in name_to_tool_map: if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool] tool = name_to_tool_map[output.tool]
@ -415,7 +418,7 @@ class AgentExecutor(Chain, BaseModel):
llm_prefix = "" if return_direct else self.agent.llm_prefix llm_prefix = "" if return_direct else self.agent.llm_prefix
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = tool.run( observation = tool.run(
output, output.tool_input,
verbose=self.verbose, verbose=self.verbose,
color=color, color=color,
llm_prefix=llm_prefix, llm_prefix=llm_prefix,
@ -423,7 +426,7 @@ class AgentExecutor(Chain, BaseModel):
) )
else: else:
observation = InvalidTool().run( observation = InvalidTool().run(
output, output.tool_input,
verbose=self.verbose, verbose=self.verbose,
color=None, color=None,
llm_prefix="", llm_prefix="",
@ -451,6 +454,9 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
return output return output
self.callback_manager.on_agent_action(
output, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool # Otherwise we lookup the tool
if output.tool in name_to_tool_map: if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool] tool = name_to_tool_map[output.tool]
@ -459,7 +465,7 @@ class AgentExecutor(Chain, BaseModel):
llm_prefix = "" if return_direct else self.agent.llm_prefix llm_prefix = "" if return_direct else self.agent.llm_prefix
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = await tool.arun( observation = await tool.arun(
output, output.tool_input,
verbose=self.verbose, verbose=self.verbose,
color=color, color=color,
llm_prefix=llm_prefix, llm_prefix=llm_prefix,
@ -467,7 +473,7 @@ class AgentExecutor(Chain, BaseModel):
) )
else: else:
observation = await InvalidTool().arun( observation = await InvalidTool().arun(
output, output.tool_input,
verbose=self.verbose, verbose=self.verbose,
color=None, color=None,
llm_prefix="", llm_prefix="",

View File

@ -68,7 +68,7 @@ class BaseCallbackHandler(ABC):
@abstractmethod @abstractmethod
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any: ) -> Any:
"""Run when tool starts running.""" """Run when tool starts running."""
@ -86,6 +86,10 @@ class BaseCallbackHandler(ABC):
def on_text(self, text: str, **kwargs: Any) -> Any: def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text.""" """Run on arbitrary text."""
@abstractmethod
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
@abstractmethod @abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end.""" """Run on agent end."""
@ -203,7 +207,7 @@ class CallbackManager(BaseCallbackManager):
def on_tool_start( def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
action: AgentAction, input_str: str,
verbose: bool = False, verbose: bool = False,
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
@ -211,7 +215,16 @@ class CallbackManager(BaseCallbackManager):
for handler in self.handlers: for handler in self.handlers:
if not handler.ignore_agent: if not handler.ignore_agent:
if verbose or handler.always_verbose: if verbose or handler.always_verbose:
handler.on_tool_start(serialized, action, **kwargs) handler.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None: def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when tool ends running.""" """Run when tool ends running."""
@ -293,7 +306,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""Run when chain errors.""" """Run when chain errors."""
async def on_tool_start( async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
@ -308,6 +321,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_text(self, text: str, **kwargs: Any) -> None: async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text.""" """Run on arbitrary text."""
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end.""" """Run on agent end."""
@ -452,7 +468,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_tool_start( async def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
action: AgentAction, input_str: str,
verbose: bool = False, verbose: bool = False,
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
@ -461,12 +477,12 @@ class AsyncCallbackManager(BaseCallbackManager):
if not handler.ignore_agent: if not handler.ignore_agent:
if verbose or handler.always_verbose: if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_start): if asyncio.iscoroutinefunction(handler.on_tool_start):
await handler.on_tool_start(serialized, action, **kwargs) await handler.on_tool_start(serialized, input_str, **kwargs)
else: else:
await asyncio.get_event_loop().run_in_executor( await asyncio.get_event_loop().run_in_executor(
None, None,
functools.partial( functools.partial(
handler.on_tool_start, serialized, action, **kwargs handler.on_tool_start, serialized, input_str, **kwargs
), ),
) )
@ -514,6 +530,23 @@ class AsyncCallbackManager(BaseCallbackManager):
None, functools.partial(handler.on_text, text, **kwargs) None, functools.partial(handler.on_text, text, **kwargs)
) )
async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent action."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_agent_action):
await handler.on_agent_action(action, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_action, action, **kwargs
),
)
async def on_agent_finish( async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None: ) -> None:

View File

@ -58,8 +58,7 @@ class OpenAICallbackHandler(BaseCallbackHandler):
def on_tool_start( def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
action: AgentAction, input_str: str,
color: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Print out the log in specified color.""" """Print out the log in specified color."""
@ -92,6 +91,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
"""Run when agent ends.""" """Run when agent ends."""
pass pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_agent_finish( def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None: ) -> None:

View File

@ -78,11 +78,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
self._callback_manager.on_chain_error(error, **kwargs) self._callback_manager.on_chain_error(error, **kwargs)
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
with self._lock: with self._lock:
self._callback_manager.on_tool_start(serialized, action, **kwargs) self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
with self._lock:
self._callback_manager.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None: def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.""" """Run when tool ends running."""

View File

@ -53,11 +53,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
def on_tool_start( def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
action: AgentAction, input_str: str,
color: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Print out the log in specified color.""" """Do nothing."""
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color if color else self.color) print_text(action.log, color=color if color else self.color)
def on_tool_end( def on_tool_end(

View File

@ -41,10 +41,14 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Run when chain errors.""" """Run when chain errors."""
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None: def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.""" """Run when tool ends running."""

View File

@ -52,10 +52,14 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def on_tool_start( def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
action: AgentAction, input_str: str,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Print out the log in specified color.""" """Print out the log in specified color."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
# st.write requires two spaces before a newline to render it # st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n")) st.markdown(action.log.replace("\n", " \n"))

View File

@ -199,7 +199,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._end_trace() self._end_trace()
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Start a trace for a tool run.""" """Start a trace for a tool run."""
if self._session is None: if self._session is None:
@ -209,8 +209,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
tool_run = ToolRun( tool_run = ToolRun(
serialized=serialized, serialized=serialized,
action=action.tool, # TODO: this is duplicate info as above, not needed.
tool_input=action.tool_input, action=str(serialized),
tool_input=input_str,
extra=kwargs, extra=kwargs,
start_time=datetime.utcnow(), start_time=datetime.utcnow(),
execution_order=self._execution_order, execution_order=self._execution_order,
@ -250,6 +251,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Handle an agent finish message.""" """Handle an agent finish message."""
pass pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
class Tracer(BaseTracer, ABC): class Tracer(BaseTracer, ABC):
"""A non-thread safe implementation of the BaseTracer interface.""" """A non-thread safe implementation of the BaseTracer interface."""

View File

@ -7,7 +7,6 @@ from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import AgentAction
class BaseTool(BaseModel): class BaseTool(BaseModel):
@ -45,12 +44,11 @@ class BaseTool(BaseModel):
def __call__(self, tool_input: str) -> str: def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input.""" """Make tools callable with str input."""
agent_action = AgentAction(tool_input=tool_input, tool=self.name, log="") return self.run(tool_input)
return self.run(agent_action)
def run( def run(
self, self,
action: AgentAction, tool_input: str,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
@ -61,13 +59,13 @@ class BaseTool(BaseModel):
verbose = self.verbose verbose = self.verbose
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
action, tool_input,
verbose=verbose, verbose=verbose,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
observation = self._run(action.tool_input) observation = self._run(tool_input)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose) self.callback_manager.on_tool_error(e, verbose=verbose)
raise e raise e
@ -78,7 +76,7 @@ class BaseTool(BaseModel):
async def arun( async def arun(
self, self,
action: AgentAction, tool_input: str,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
@ -90,7 +88,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_start( await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
action, tool_input,
verbose=verbose, verbose=verbose,
color=start_color, color=start_color,
**kwargs, **kwargs,
@ -98,14 +96,14 @@ class BaseTool(BaseModel):
else: else:
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
action, tool_input,
verbose=verbose, verbose=verbose,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = await self._arun(action.tool_input) observation = await self._arun(tool_input)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose) await self.callback_manager.on_tool_error(e, verbose=verbose)

View File

@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run # 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3 assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2 assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1 assert handler.tool_starts == 2
assert handler.starts == 6 assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end # 1 extra agent end
assert handler.ends == 7 assert handler.ends == 7
assert handler.errors == 0 assert handler.errors == 0
@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run # 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3 assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2 assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1 assert handler.tool_starts == 2
assert handler.starts == 6 assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end # 1 extra agent end
assert handler.ends == 7 assert handler.ends == 7
assert handler.errors == 0 assert handler.errors == 0

View File

@ -2,7 +2,6 @@
import pytest import pytest
from langchain.agents.tools import Tool, tool from langchain.agents.tools import Tool, tool
from langchain.schema import AgentAction
def test_unnamed_decorator() -> None: def test_unnamed_decorator() -> None:
@ -101,7 +100,4 @@ async def test_create_async_tool() -> None:
assert test_tool.name == "test_name" assert test_tool.name == "test_name"
assert test_tool.description == "test_description" assert test_tool.description == "test_description"
assert test_tool.coroutine is not None assert test_tool.coroutine is not None
assert ( assert await test_tool.arun("foo") == "foo"
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
== "foo"
)

View File

@ -94,7 +94,7 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
self.errors += 1 self.errors += 1
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
self.tool_starts += 1 self.tool_starts += 1
@ -120,6 +120,11 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
self.agent_ends += 1 self.agent_ends += 1
self.ends += 1 self.ends += 1
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler): class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Fake async callback handler for testing.""" """Fake async callback handler for testing."""
@ -165,7 +170,7 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
self.errors += 1 self.errors += 1
async def on_tool_start( async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
self.tool_starts += 1 self.tool_starts += 1
@ -190,3 +195,8 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Run when agent ends running.""" """Run when agent ends running."""
self.agent_ends += 1 self.agent_ends += 1
self.ends += 1 self.ends += 1
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1

View File

@ -9,7 +9,7 @@ from langchain.callbacks.base import (
CallbackManager, CallbackManager,
) )
from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import ( from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler, BaseFakeCallbackHandler,
FakeAsyncCallbackHandler, FakeAsyncCallbackHandler,
@ -27,7 +27,7 @@ def _test_callback_manager(
manager.on_chain_start({"name": "foo"}, {}) manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({}) manager.on_chain_end({})
manager.on_chain_error(Exception()) manager.on_chain_error(Exception())
manager.on_tool_start({}, AgentAction("", "", "")) manager.on_tool_start({}, "")
manager.on_tool_end("") manager.on_tool_end("")
manager.on_tool_error(Exception()) manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={})) manager.on_agent_finish(AgentFinish(log="", return_values={}))
@ -44,7 +44,7 @@ async def _test_callback_manager_async(
await manager.on_chain_start({"name": "foo"}, {}) await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({}) await manager.on_chain_end({})
await manager.on_chain_error(Exception()) await manager.on_chain_error(Exception())
await manager.on_tool_start({}, AgentAction("", "", "")) await manager.on_tool_start({}, "")
await manager.on_tool_end("") await manager.on_tool_end("")
await manager.on_tool_error(Exception()) await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={})) await manager.on_agent_finish(AgentFinish(log="", return_values={}))
@ -73,7 +73,7 @@ def _test_callback_manager_pass_in_verbose(
manager.on_chain_start({"name": "foo"}, {}, verbose=True) manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True) manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True) manager.on_chain_error(Exception(), verbose=True)
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True) manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True) manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True) manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True) manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
@ -136,7 +136,7 @@ def test_ignore_agent() -> None:
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True) handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True) handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2]) manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True) manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True) manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True) manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True) manager.on_agent_finish(AgentFinish({}, ""), verbose=True)

View File

@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
TracerSession, TracerSession,
) )
from langchain.callbacks.tracers.schemas import TracerSessionCreate from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.schema import AgentAction, LLMResult from langchain.schema import LLMResult
TEST_SESSION_ID = 2023 TEST_SESSION_ID = 2023
@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
serialized={}, serialized={},
tool_input="test", tool_input="test",
output="test", output="test",
action="action", action="{}",
session_id=TEST_SESSION_ID, session_id=TEST_SESSION_ID,
error=None, error=None,
child_runs=[ child_runs=[
@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
def _perform_nested_run(tracer: BaseTracer) -> None: def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run.""" """Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={}) tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start( tracer.on_tool_start(serialized={}, input_str="test")
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]])) tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_end("test") tracer.on_tool_end("test")
@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
serialized={}, serialized={},
tool_input="test", tool_input="test",
output="test", output="test",
action="action", action="{}",
session_id=TEST_SESSION_ID, session_id=TEST_SESSION_ID,
error=None, error=None,
) )
tracer = FakeTracer() tracer = FakeTracer()
tracer.new_session() tracer.new_session()
tracer.on_tool_start( tracer.on_tool_start(serialized={}, input_str="test")
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_end("test") tracer.on_tool_end("test")
assert tracer.runs == [compare_run] assert tracer.runs == [compare_run]
@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
serialized={}, serialized={},
tool_input="test", tool_input="test",
output=None, output=None,
action="action", action="{}",
session_id=TEST_SESSION_ID, session_id=TEST_SESSION_ID,
error=repr(exception), error=repr(exception),
) )
tracer = FakeTracer() tracer = FakeTracer()
tracer.new_session() tracer.new_session()
tracer.on_tool_start( tracer.on_tool_start(serialized={}, input_str="test")
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_error(exception) tracer.on_tool_error(exception)
assert tracer.runs == [compare_run] assert tracer.runs == [compare_run]
@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
tracer.on_llm_end(response=LLMResult([[]])) tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]])) tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_start( tracer.on_tool_start(serialized={}, input_str="test")
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception) tracer.on_llm_error(exception)
tracer.on_tool_error(exception) tracer.on_tool_error(exception)
@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
error=repr(exception), error=repr(exception),
tool_input="test", tool_input="test",
output=None, output=None,
action="action", action="{}",
child_runs=[ child_runs=[
LLMRun( LLMRun(
id=None, id=None,