mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
rfc: callback changes (#1165)
conceptually, no reason a tool should know what an "agent action" is unless any objections, can change in all callback handlers
This commit is contained in:
parent
fb83cd4ff4
commit
b7708bbec6
@ -407,6 +407,9 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
# If the tool chosen is the finishing tool, then we end and return.
|
# If the tool chosen is the finishing tool, then we end and return.
|
||||||
if isinstance(output, AgentFinish):
|
if isinstance(output, AgentFinish):
|
||||||
return output
|
return output
|
||||||
|
self.callback_manager.on_agent_action(
|
||||||
|
output, verbose=self.verbose, color="green"
|
||||||
|
)
|
||||||
# Otherwise we lookup the tool
|
# Otherwise we lookup the tool
|
||||||
if output.tool in name_to_tool_map:
|
if output.tool in name_to_tool_map:
|
||||||
tool = name_to_tool_map[output.tool]
|
tool = name_to_tool_map[output.tool]
|
||||||
@ -415,7 +418,7 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
observation = tool.run(
|
observation = tool.run(
|
||||||
output,
|
output.tool_input,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
color=color,
|
color=color,
|
||||||
llm_prefix=llm_prefix,
|
llm_prefix=llm_prefix,
|
||||||
@ -423,7 +426,7 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
observation = InvalidTool().run(
|
observation = InvalidTool().run(
|
||||||
output,
|
output.tool_input,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
color=None,
|
color=None,
|
||||||
llm_prefix="",
|
llm_prefix="",
|
||||||
@ -451,6 +454,9 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
# If the tool chosen is the finishing tool, then we end and return.
|
# If the tool chosen is the finishing tool, then we end and return.
|
||||||
if isinstance(output, AgentFinish):
|
if isinstance(output, AgentFinish):
|
||||||
return output
|
return output
|
||||||
|
self.callback_manager.on_agent_action(
|
||||||
|
output, verbose=self.verbose, color="green"
|
||||||
|
)
|
||||||
# Otherwise we lookup the tool
|
# Otherwise we lookup the tool
|
||||||
if output.tool in name_to_tool_map:
|
if output.tool in name_to_tool_map:
|
||||||
tool = name_to_tool_map[output.tool]
|
tool = name_to_tool_map[output.tool]
|
||||||
@ -459,7 +465,7 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
observation = await tool.arun(
|
observation = await tool.arun(
|
||||||
output,
|
output.tool_input,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
color=color,
|
color=color,
|
||||||
llm_prefix=llm_prefix,
|
llm_prefix=llm_prefix,
|
||||||
@ -467,7 +473,7 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
observation = await InvalidTool().arun(
|
observation = await InvalidTool().arun(
|
||||||
output,
|
output.tool_input,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
color=None,
|
color=None,
|
||||||
llm_prefix="",
|
llm_prefix="",
|
||||||
|
@ -68,7 +68,7 @@ class BaseCallbackHandler(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
|
|
||||||
@ -86,6 +86,10 @@ class BaseCallbackHandler(ABC):
|
|||||||
def on_text(self, text: str, **kwargs: Any) -> Any:
|
def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||||
"""Run on arbitrary text."""
|
"""Run on arbitrary text."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||||
"""Run on agent end."""
|
"""Run on agent end."""
|
||||||
@ -203,7 +207,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
action: AgentAction,
|
input_str: str,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -211,7 +215,16 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
if not handler.ignore_agent:
|
if not handler.ignore_agent:
|
||||||
if verbose or handler.always_verbose:
|
if verbose or handler.always_verbose:
|
||||||
handler.on_tool_start(serialized, action, **kwargs)
|
handler.on_tool_start(serialized, input_str, **kwargs)
|
||||||
|
|
||||||
|
def on_agent_action(
|
||||||
|
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
|
if verbose or handler.always_verbose:
|
||||||
|
handler.on_agent_action(action, **kwargs)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
@ -293,7 +306,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
|
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
|
|
||||||
@ -308,6 +321,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
"""Run on arbitrary text."""
|
"""Run on arbitrary text."""
|
||||||
|
|
||||||
|
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent action."""
|
||||||
|
|
||||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
"""Run on agent end."""
|
"""Run on agent end."""
|
||||||
|
|
||||||
@ -452,7 +468,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
action: AgentAction,
|
input_str: str,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -461,12 +477,12 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
if not handler.ignore_agent:
|
if not handler.ignore_agent:
|
||||||
if verbose or handler.always_verbose:
|
if verbose or handler.always_verbose:
|
||||||
if asyncio.iscoroutinefunction(handler.on_tool_start):
|
if asyncio.iscoroutinefunction(handler.on_tool_start):
|
||||||
await handler.on_tool_start(serialized, action, **kwargs)
|
await handler.on_tool_start(serialized, input_str, **kwargs)
|
||||||
else:
|
else:
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
functools.partial(
|
functools.partial(
|
||||||
handler.on_tool_start, serialized, action, **kwargs
|
handler.on_tool_start, serialized, input_str, **kwargs
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -514,6 +530,23 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
None, functools.partial(handler.on_text, text, **kwargs)
|
None, functools.partial(handler.on_text, text, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def on_agent_action(
|
||||||
|
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run on agent action."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
|
if verbose or handler.always_verbose:
|
||||||
|
if asyncio.iscoroutinefunction(handler.on_agent_action):
|
||||||
|
await handler.on_agent_action(action, **kwargs)
|
||||||
|
else:
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
functools.partial(
|
||||||
|
handler.on_agent_action, action, **kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
async def on_agent_finish(
|
async def on_agent_finish(
|
||||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -58,8 +58,7 @@ class OpenAICallbackHandler(BaseCallbackHandler):
|
|||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
action: AgentAction,
|
input_str: str,
|
||||||
color: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Print out the log in specified color."""
|
"""Print out the log in specified color."""
|
||||||
@ -92,6 +91,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run when agent ends."""
|
"""Run when agent ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
|
pass
|
||||||
|
|
||||||
def on_agent_finish(
|
def on_agent_finish(
|
||||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -78,11 +78,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
|||||||
self._callback_manager.on_chain_error(error, **kwargs)
|
self._callback_manager.on_chain_error(error, **kwargs)
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._callback_manager.on_tool_start(serialized, action, **kwargs)
|
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
|
||||||
|
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_agent_action(action, **kwargs)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
@ -53,11 +53,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
action: AgentAction,
|
input_str: str,
|
||||||
color: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Print out the log in specified color."""
|
"""Do nothing."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_agent_action(
|
||||||
|
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
print_text(action.log, color=color if color else self.color)
|
print_text(action.log, color=color if color else self.color)
|
||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
|
@ -41,10 +41,14 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
|
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
|
pass
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
|
||||||
|
@ -52,10 +52,14 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
action: AgentAction,
|
input_str: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Print out the log in specified color."""
|
"""Print out the log in specified color."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
# st.write requires two spaces before a newline to render it
|
# st.write requires two spaces before a newline to render it
|
||||||
st.markdown(action.log.replace("\n", " \n"))
|
st.markdown(action.log.replace("\n", " \n"))
|
||||||
|
|
||||||
|
@ -199,7 +199,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
self._end_trace()
|
self._end_trace()
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for a tool run."""
|
"""Start a trace for a tool run."""
|
||||||
if self._session is None:
|
if self._session is None:
|
||||||
@ -209,8 +209,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
tool_run = ToolRun(
|
tool_run = ToolRun(
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
action=action.tool,
|
# TODO: this is duplicate info as above, not needed.
|
||||||
tool_input=action.tool_input,
|
action=str(serialized),
|
||||||
|
tool_input=input_str,
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
start_time=datetime.utcnow(),
|
start_time=datetime.utcnow(),
|
||||||
execution_order=self._execution_order,
|
execution_order=self._execution_order,
|
||||||
@ -250,6 +251,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
"""Handle an agent finish message."""
|
"""Handle an agent finish message."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Do nothing."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Tracer(BaseTracer, ABC):
|
class Tracer(BaseTracer, ABC):
|
||||||
"""A non-thread safe implementation of the BaseTracer interface."""
|
"""A non-thread safe implementation of the BaseTracer interface."""
|
||||||
|
@ -7,7 +7,6 @@ from pydantic import BaseModel, Extra, Field, validator
|
|||||||
|
|
||||||
from langchain.callbacks import get_callback_manager
|
from langchain.callbacks import get_callback_manager
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.schema import AgentAction
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(BaseModel):
|
class BaseTool(BaseModel):
|
||||||
@ -45,12 +44,11 @@ class BaseTool(BaseModel):
|
|||||||
|
|
||||||
def __call__(self, tool_input: str) -> str:
|
def __call__(self, tool_input: str) -> str:
|
||||||
"""Make tools callable with str input."""
|
"""Make tools callable with str input."""
|
||||||
agent_action = AgentAction(tool_input=tool_input, tool=self.name, log="")
|
return self.run(tool_input)
|
||||||
return self.run(agent_action)
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
action: AgentAction,
|
tool_input: str,
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
@ -61,13 +59,13 @@ class BaseTool(BaseModel):
|
|||||||
verbose = self.verbose
|
verbose = self.verbose
|
||||||
self.callback_manager.on_tool_start(
|
self.callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
action,
|
tool_input,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
color=start_color,
|
color=start_color,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
observation = self._run(action.tool_input)
|
observation = self._run(tool_input)
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
self.callback_manager.on_tool_error(e, verbose=verbose)
|
self.callback_manager.on_tool_error(e, verbose=verbose)
|
||||||
raise e
|
raise e
|
||||||
@ -78,7 +76,7 @@ class BaseTool(BaseModel):
|
|||||||
|
|
||||||
async def arun(
|
async def arun(
|
||||||
self,
|
self,
|
||||||
action: AgentAction,
|
tool_input: str,
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
@ -90,7 +88,7 @@ class BaseTool(BaseModel):
|
|||||||
if self.callback_manager.is_async:
|
if self.callback_manager.is_async:
|
||||||
await self.callback_manager.on_tool_start(
|
await self.callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
action,
|
tool_input,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
color=start_color,
|
color=start_color,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -98,14 +96,14 @@ class BaseTool(BaseModel):
|
|||||||
else:
|
else:
|
||||||
self.callback_manager.on_tool_start(
|
self.callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
action,
|
tool_input,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
color=start_color,
|
color=start_color,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
observation = await self._arun(action.tool_input)
|
observation = await self._arun(tool_input)
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
if self.callback_manager.is_async:
|
if self.callback_manager.is_async:
|
||||||
await self.callback_manager.on_tool_error(e, verbose=verbose)
|
await self.callback_manager.on_tool_error(e, verbose=verbose)
|
||||||
|
@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
|
|||||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||||
assert handler.chain_starts == handler.chain_ends == 3
|
assert handler.chain_starts == handler.chain_ends == 3
|
||||||
assert handler.llm_starts == handler.llm_ends == 2
|
assert handler.llm_starts == handler.llm_ends == 2
|
||||||
assert handler.tool_starts == handler.tool_ends == 1
|
assert handler.tool_starts == 2
|
||||||
assert handler.starts == 6
|
assert handler.tool_ends == 1
|
||||||
|
# 1 extra agent action
|
||||||
|
assert handler.starts == 7
|
||||||
# 1 extra agent end
|
# 1 extra agent end
|
||||||
assert handler.ends == 7
|
assert handler.ends == 7
|
||||||
assert handler.errors == 0
|
assert handler.errors == 0
|
||||||
@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
|
|||||||
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
||||||
assert handler.chain_starts == handler.chain_ends == 3
|
assert handler.chain_starts == handler.chain_ends == 3
|
||||||
assert handler.llm_starts == handler.llm_ends == 2
|
assert handler.llm_starts == handler.llm_ends == 2
|
||||||
assert handler.tool_starts == handler.tool_ends == 1
|
assert handler.tool_starts == 2
|
||||||
assert handler.starts == 6
|
assert handler.tool_ends == 1
|
||||||
|
# 1 extra agent action
|
||||||
|
assert handler.starts == 7
|
||||||
# 1 extra agent end
|
# 1 extra agent end
|
||||||
assert handler.ends == 7
|
assert handler.ends == 7
|
||||||
assert handler.errors == 0
|
assert handler.errors == 0
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.agents.tools import Tool, tool
|
from langchain.agents.tools import Tool, tool
|
||||||
from langchain.schema import AgentAction
|
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
@ -101,7 +100,4 @@ async def test_create_async_tool() -> None:
|
|||||||
assert test_tool.name == "test_name"
|
assert test_tool.name == "test_name"
|
||||||
assert test_tool.description == "test_description"
|
assert test_tool.description == "test_description"
|
||||||
assert test_tool.coroutine is not None
|
assert test_tool.coroutine is not None
|
||||||
assert (
|
assert await test_tool.arun("foo") == "foo"
|
||||||
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
|
|
||||||
== "foo"
|
|
||||||
)
|
|
||||||
|
@ -94,7 +94,7 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
|||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
self.tool_starts += 1
|
self.tool_starts += 1
|
||||||
@ -120,6 +120,11 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
|||||||
self.agent_ends += 1
|
self.agent_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||||
|
"""Run on agent action."""
|
||||||
|
self.tool_starts += 1
|
||||||
|
self.starts += 1
|
||||||
|
|
||||||
|
|
||||||
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||||
"""Fake async callback handler for testing."""
|
"""Fake async callback handler for testing."""
|
||||||
@ -165,7 +170,7 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
|||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
self.tool_starts += 1
|
self.tool_starts += 1
|
||||||
@ -190,3 +195,8 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
|||||||
"""Run when agent ends running."""
|
"""Run when agent ends running."""
|
||||||
self.agent_ends += 1
|
self.agent_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
|
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent action."""
|
||||||
|
self.tool_starts += 1
|
||||||
|
self.starts += 1
|
||||||
|
@ -9,7 +9,7 @@ from langchain.callbacks.base import (
|
|||||||
CallbackManager,
|
CallbackManager,
|
||||||
)
|
)
|
||||||
from langchain.callbacks.shared import SharedCallbackManager
|
from langchain.callbacks.shared import SharedCallbackManager
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentFinish, LLMResult
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||||
BaseFakeCallbackHandler,
|
BaseFakeCallbackHandler,
|
||||||
FakeAsyncCallbackHandler,
|
FakeAsyncCallbackHandler,
|
||||||
@ -27,7 +27,7 @@ def _test_callback_manager(
|
|||||||
manager.on_chain_start({"name": "foo"}, {})
|
manager.on_chain_start({"name": "foo"}, {})
|
||||||
manager.on_chain_end({})
|
manager.on_chain_end({})
|
||||||
manager.on_chain_error(Exception())
|
manager.on_chain_error(Exception())
|
||||||
manager.on_tool_start({}, AgentAction("", "", ""))
|
manager.on_tool_start({}, "")
|
||||||
manager.on_tool_end("")
|
manager.on_tool_end("")
|
||||||
manager.on_tool_error(Exception())
|
manager.on_tool_error(Exception())
|
||||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||||
@ -44,7 +44,7 @@ async def _test_callback_manager_async(
|
|||||||
await manager.on_chain_start({"name": "foo"}, {})
|
await manager.on_chain_start({"name": "foo"}, {})
|
||||||
await manager.on_chain_end({})
|
await manager.on_chain_end({})
|
||||||
await manager.on_chain_error(Exception())
|
await manager.on_chain_error(Exception())
|
||||||
await manager.on_tool_start({}, AgentAction("", "", ""))
|
await manager.on_tool_start({}, "")
|
||||||
await manager.on_tool_end("")
|
await manager.on_tool_end("")
|
||||||
await manager.on_tool_error(Exception())
|
await manager.on_tool_error(Exception())
|
||||||
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||||
@ -73,7 +73,7 @@ def _test_callback_manager_pass_in_verbose(
|
|||||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||||
manager.on_chain_end({}, verbose=True)
|
manager.on_chain_end({}, verbose=True)
|
||||||
manager.on_chain_error(Exception(), verbose=True)
|
manager.on_chain_error(Exception(), verbose=True)
|
||||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
manager.on_tool_start({}, "", verbose=True)
|
||||||
manager.on_tool_end("", verbose=True)
|
manager.on_tool_end("", verbose=True)
|
||||||
manager.on_tool_error(Exception(), verbose=True)
|
manager.on_tool_error(Exception(), verbose=True)
|
||||||
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
||||||
@ -136,7 +136,7 @@ def test_ignore_agent() -> None:
|
|||||||
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
||||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||||
manager = CallbackManager(handlers=[handler1, handler2])
|
manager = CallbackManager(handlers=[handler1, handler2])
|
||||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
manager.on_tool_start({}, "", verbose=True)
|
||||||
manager.on_tool_end("", verbose=True)
|
manager.on_tool_end("", verbose=True)
|
||||||
manager.on_tool_error(Exception(), verbose=True)
|
manager.on_tool_error(Exception(), verbose=True)
|
||||||
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
||||||
|
@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
|
|||||||
TracerSession,
|
TracerSession,
|
||||||
)
|
)
|
||||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||||
from langchain.schema import AgentAction, LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
TEST_SESSION_ID = 2023
|
TEST_SESSION_ID = 2023
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
|||||||
serialized={},
|
serialized={},
|
||||||
tool_input="test",
|
tool_input="test",
|
||||||
output="test",
|
output="test",
|
||||||
action="action",
|
action="{}",
|
||||||
session_id=TEST_SESSION_ID,
|
session_id=TEST_SESSION_ID,
|
||||||
error=None,
|
error=None,
|
||||||
child_runs=[
|
child_runs=[
|
||||||
@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
|||||||
def _perform_nested_run(tracer: BaseTracer) -> None:
|
def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||||
"""Perform a nested run."""
|
"""Perform a nested run."""
|
||||||
tracer.on_chain_start(serialized={}, inputs={})
|
tracer.on_chain_start(serialized={}, inputs={})
|
||||||
tracer.on_tool_start(
|
tracer.on_tool_start(serialized={}, input_str="test")
|
||||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
|
||||||
)
|
|
||||||
tracer.on_llm_start(serialized={}, prompts=[])
|
tracer.on_llm_start(serialized={}, prompts=[])
|
||||||
tracer.on_llm_end(response=LLMResult([[]]))
|
tracer.on_llm_end(response=LLMResult([[]]))
|
||||||
tracer.on_tool_end("test")
|
tracer.on_tool_end("test")
|
||||||
@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
|
|||||||
serialized={},
|
serialized={},
|
||||||
tool_input="test",
|
tool_input="test",
|
||||||
output="test",
|
output="test",
|
||||||
action="action",
|
action="{}",
|
||||||
session_id=TEST_SESSION_ID,
|
session_id=TEST_SESSION_ID,
|
||||||
error=None,
|
error=None,
|
||||||
)
|
)
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
|
|
||||||
tracer.new_session()
|
tracer.new_session()
|
||||||
tracer.on_tool_start(
|
tracer.on_tool_start(serialized={}, input_str="test")
|
||||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
|
||||||
)
|
|
||||||
tracer.on_tool_end("test")
|
tracer.on_tool_end("test")
|
||||||
assert tracer.runs == [compare_run]
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
|
|||||||
serialized={},
|
serialized={},
|
||||||
tool_input="test",
|
tool_input="test",
|
||||||
output=None,
|
output=None,
|
||||||
action="action",
|
action="{}",
|
||||||
session_id=TEST_SESSION_ID,
|
session_id=TEST_SESSION_ID,
|
||||||
error=repr(exception),
|
error=repr(exception),
|
||||||
)
|
)
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
|
|
||||||
tracer.new_session()
|
tracer.new_session()
|
||||||
tracer.on_tool_start(
|
tracer.on_tool_start(serialized={}, input_str="test")
|
||||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
|
||||||
)
|
|
||||||
tracer.on_tool_error(exception)
|
tracer.on_tool_error(exception)
|
||||||
assert tracer.runs == [compare_run]
|
assert tracer.runs == [compare_run]
|
||||||
|
|
||||||
@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
|||||||
tracer.on_llm_end(response=LLMResult([[]]))
|
tracer.on_llm_end(response=LLMResult([[]]))
|
||||||
tracer.on_llm_start(serialized={}, prompts=[])
|
tracer.on_llm_start(serialized={}, prompts=[])
|
||||||
tracer.on_llm_end(response=LLMResult([[]]))
|
tracer.on_llm_end(response=LLMResult([[]]))
|
||||||
tracer.on_tool_start(
|
tracer.on_tool_start(serialized={}, input_str="test")
|
||||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
|
||||||
)
|
|
||||||
tracer.on_llm_start(serialized={}, prompts=[])
|
tracer.on_llm_start(serialized={}, prompts=[])
|
||||||
tracer.on_llm_error(exception)
|
tracer.on_llm_error(exception)
|
||||||
tracer.on_tool_error(exception)
|
tracer.on_tool_error(exception)
|
||||||
@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
|||||||
error=repr(exception),
|
error=repr(exception),
|
||||||
tool_input="test",
|
tool_input="test",
|
||||||
output=None,
|
output=None,
|
||||||
action="action",
|
action="{}",
|
||||||
child_runs=[
|
child_runs=[
|
||||||
LLMRun(
|
LLMRun(
|
||||||
id=None,
|
id=None,
|
||||||
|
Loading…
Reference in New Issue
Block a user