mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
rfc: callback changes (#1165)
conceptually, no reason a tool should know what an "agent action" is unless any objections, can change in all callback handlers
This commit is contained in:
parent
fb83cd4ff4
commit
b7708bbec6
@ -407,6 +407,9 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
self.callback_manager.on_agent_action(
|
||||
output, verbose=self.verbose, color="green"
|
||||
)
|
||||
# Otherwise we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[output.tool]
|
||||
@ -415,7 +418,7 @@ class AgentExecutor(Chain, BaseModel):
|
||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = tool.run(
|
||||
output,
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
llm_prefix=llm_prefix,
|
||||
@ -423,7 +426,7 @@ class AgentExecutor(Chain, BaseModel):
|
||||
)
|
||||
else:
|
||||
observation = InvalidTool().run(
|
||||
output,
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
llm_prefix="",
|
||||
@ -451,6 +454,9 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
self.callback_manager.on_agent_action(
|
||||
output, verbose=self.verbose, color="green"
|
||||
)
|
||||
# Otherwise we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[output.tool]
|
||||
@ -459,7 +465,7 @@ class AgentExecutor(Chain, BaseModel):
|
||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await tool.arun(
|
||||
output,
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
llm_prefix=llm_prefix,
|
||||
@ -467,7 +473,7 @@ class AgentExecutor(Chain, BaseModel):
|
||||
)
|
||||
else:
|
||||
observation = await InvalidTool().arun(
|
||||
output,
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
llm_prefix="",
|
||||
|
@ -68,7 +68,7 @@ class BaseCallbackHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@ -86,6 +86,10 @@ class BaseCallbackHandler(ABC):
|
||||
def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
@ -203,7 +207,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
input_str: str,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
@ -211,7 +215,16 @@ class CallbackManager(BaseCallbackManager):
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_start(serialized, action, **kwargs)
|
||||
handler.on_tool_start(serialized, input_str, **kwargs)
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_agent_action(action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
@ -293,7 +306,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@ -308,6 +321,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
@ -452,7 +468,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
input_str: str,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
@ -461,12 +477,12 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_tool_start):
|
||||
await handler.on_tool_start(serialized, action, **kwargs)
|
||||
await handler.on_tool_start(serialized, input_str, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_tool_start, serialized, action, **kwargs
|
||||
handler.on_tool_start, serialized, input_str, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
@ -514,6 +530,23 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
None, functools.partial(handler.on_text, text, **kwargs)
|
||||
)
|
||||
|
||||
async def on_agent_action(
|
||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent action."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_agent_action):
|
||||
await handler.on_agent_action(action, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_agent_action, action, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -58,8 +58,7 @@ class OpenAICallbackHandler(BaseCallbackHandler):
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
color: Optional[str] = None,
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
@ -92,6 +91,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
|
||||
"""Run when agent ends."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -78,11 +78,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
self._callback_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_start(serialized, action, **kwargs)
|
||||
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_action(action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
@ -53,11 +53,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
color: Optional[str] = None,
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
print_text(action.log, color=color if color else self.color)
|
||||
|
||||
def on_tool_end(
|
||||
|
@ -41,10 +41,14 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
|
@ -52,10 +52,14 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
# st.write requires two spaces before a newline to render it
|
||||
st.markdown(action.log.replace("\n", " \n"))
|
||||
|
||||
|
@ -199,7 +199,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self._end_trace()
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for a tool run."""
|
||||
if self._session is None:
|
||||
@ -209,8 +209,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
tool_run = ToolRun(
|
||||
serialized=serialized,
|
||||
action=action.tool,
|
||||
tool_input=action.tool_input,
|
||||
# TODO: this is duplicate info as above, not needed.
|
||||
action=str(serialized),
|
||||
tool_input=input_str,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
@ -250,6 +251,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Handle an agent finish message."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
|
||||
class Tracer(BaseTracer, ABC):
|
||||
"""A non-thread safe implementation of the BaseTracer interface."""
|
||||
|
@ -7,7 +7,6 @@ from pydantic import BaseModel, Extra, Field, validator
|
||||
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
|
||||
class BaseTool(BaseModel):
|
||||
@ -45,12 +44,11 @@ class BaseTool(BaseModel):
|
||||
|
||||
def __call__(self, tool_input: str) -> str:
|
||||
"""Make tools callable with str input."""
|
||||
agent_action = AgentAction(tool_input=tool_input, tool=self.name, log="")
|
||||
return self.run(agent_action)
|
||||
return self.run(tool_input)
|
||||
|
||||
def run(
|
||||
self,
|
||||
action: AgentAction,
|
||||
tool_input: str,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
@ -61,13 +59,13 @@ class BaseTool(BaseModel):
|
||||
verbose = self.verbose
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
action,
|
||||
tool_input,
|
||||
verbose=verbose,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
observation = self._run(action.tool_input)
|
||||
observation = self._run(tool_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose)
|
||||
raise e
|
||||
@ -78,7 +76,7 @@ class BaseTool(BaseModel):
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
action: AgentAction,
|
||||
tool_input: str,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
@ -90,7 +88,7 @@ class BaseTool(BaseModel):
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
action,
|
||||
tool_input,
|
||||
verbose=verbose,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
@ -98,14 +96,14 @@ class BaseTool(BaseModel):
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
action,
|
||||
tool_input,
|
||||
verbose=verbose,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await self._arun(action.tool_input)
|
||||
observation = await self._arun(tool_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_error(e, verbose=verbose)
|
||||
|
@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
|
||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == handler.tool_ends == 1
|
||||
assert handler.starts == 6
|
||||
assert handler.tool_starts == 2
|
||||
assert handler.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
|
||||
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == handler.tool_ends == 1
|
||||
assert handler.starts == 6
|
||||
assert handler.tool_starts == 2
|
||||
assert handler.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
|
@ -2,7 +2,6 @@
|
||||
import pytest
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@ -101,7 +100,4 @@ async def test_create_async_tool() -> None:
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.coroutine is not None
|
||||
assert (
|
||||
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
|
||||
== "foo"
|
||||
)
|
||||
assert await test_tool.arun("foo") == "foo"
|
||||
|
@ -94,7 +94,7 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
@ -120,6 +120,11 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
"""Fake async callback handler for testing."""
|
||||
@ -165,7 +170,7 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
@ -190,3 +195,8 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
@ -9,7 +9,7 @@ from langchain.callbacks.base import (
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
@ -27,7 +27,7 @@ def _test_callback_manager(
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
manager.on_tool_start({}, "")
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
@ -44,7 +44,7 @@ async def _test_callback_manager_async(
|
||||
await manager.on_chain_start({"name": "foo"}, {})
|
||||
await manager.on_chain_end({})
|
||||
await manager.on_chain_error(Exception())
|
||||
await manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
await manager.on_tool_start({}, "")
|
||||
await manager.on_tool_end("")
|
||||
await manager.on_tool_error(Exception())
|
||||
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
@ -73,7 +73,7 @@ def _test_callback_manager_pass_in_verbose(
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
manager.on_chain_error(Exception(), verbose=True)
|
||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
||||
@ -136,7 +136,7 @@ def test_ignore_agent() -> None:
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
||||
|
@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_runs=[
|
||||
@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
"""Perform a nested run."""
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_end("test")
|
||||
@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_tool_end("test")
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_tool_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_tool_error(exception)
|
||||
@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
action="{}",
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
|
Loading…
Reference in New Issue
Block a user