You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/tools/base.py

122 lines
3.9 KiB
Python

"""Base implementation for tools or skills."""
from abc import abstractmethod
from typing import Any, Optional
from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel):
    """Class responsible for defining a tool or skill for an LLM.

    Subclasses implement ``_run`` (sync) and ``_arun`` (async). Callers use
    ``run`` / ``arun`` (or call the instance directly), which wrap the
    implementation with ``on_tool_start`` / ``on_tool_end`` /
    ``on_tool_error`` notifications sent to ``callback_manager``.
    """

    # Reported to callbacks in the tool-info dict and as ``name=`` on tool end.
    name: str
    # Reported to callbacks in the tool-info dict passed to ``on_tool_start``.
    description: str
    # NOTE(review): not read by this class; presumably tells the caller (e.g. an
    # agent loop) to return the tool output directly — confirm against callers.
    return_direct: bool = False
    # Default verbosity used by ``run``/``arun`` when ``verbose`` is None.
    verbose: bool = False
    # Receives start/end/error notifications; None is replaced by the validator.
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True, always=True)
    def set_callback_manager(
        cls, callback_manager: Optional[BaseCallbackManager]
    ) -> BaseCallbackManager:
        """If callback manager is None, set it.

        This allows users to pass in None as callback manager, which is a nice UX.
        Any falsy value is replaced by ``get_callback_manager()``.
        """
        return callback_manager or get_callback_manager()

    @abstractmethod
    def _run(self, tool_input: str) -> str:
        """Use the tool."""

    @abstractmethod
    async def _arun(self, tool_input: str) -> str:
        """Use the tool asynchronously."""

    def __call__(self, tool_input: str) -> str:
        """Make tools callable with str input (delegates to ``run``)."""
        return self.run(tool_input)

    def run(
        self,
        tool_input: str,
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        **kwargs: Any
    ) -> str:
        """Run the tool, notifying the callback manager around ``_run``.

        Args:
            tool_input: String passed unchanged to ``_run``.
            verbose: Forwarded to every callback; None means use ``self.verbose``.
            start_color: Forwarded as ``color`` to ``on_tool_start``.
            color: Forwarded as ``color`` to ``on_tool_end``.
            **kwargs: Forwarded to ``on_tool_start`` and ``on_tool_end``
                (not to ``on_tool_error``).

        Returns:
            The string returned by ``_run``.

        Raises:
            Any ``Exception`` or ``KeyboardInterrupt`` from ``_run`` is reported
            via ``on_tool_error`` and then re-raised unchanged.
        """
        if verbose is None:
            verbose = self.verbose
        self.callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input,
            verbose=verbose,
            color=start_color,
            **kwargs,
        )
        try:
            observation = self._run(tool_input)
        except (Exception, KeyboardInterrupt) as e:
            self.callback_manager.on_tool_error(e, verbose=verbose)
            # Bare ``raise`` re-raises the handled exception with its original
            # traceback, without appending this line as ``raise e`` would.
            raise
        self.callback_manager.on_tool_end(
            observation, verbose=verbose, color=color, name=self.name, **kwargs
        )
        return observation

    async def _await_if_async(self, hook_result: Any) -> None:
        """Finish a callback-manager hook call made from async code.

        The hook has already been called by the time this runs. For an async
        manager (``is_async`` true) that call produced an awaitable, which is
        awaited here. For a sync manager the hook already ran and its return
        value is ignored.
        """
        if self.callback_manager.is_async:
            await hook_result

    async def arun(
        self,
        tool_input: str,
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        **kwargs: Any
    ) -> str:
        """Run the tool asynchronously, notifying the callback manager around ``_arun``.

        Works with both sync and async callback managers. Hooks are awaited
        only when ``callback_manager.is_async`` is true.

        Args:
            tool_input: String passed unchanged to ``_arun``.
            verbose: Forwarded to every callback; None means use ``self.verbose``.
            start_color: Forwarded as ``color`` to ``on_tool_start``.
            color: Forwarded as ``color`` to ``on_tool_end``.
            **kwargs: Forwarded to ``on_tool_start`` and ``on_tool_end``
                (not to ``on_tool_error``).

        Returns:
            The string returned by awaiting ``_arun``.

        Raises:
            Any ``Exception`` or ``KeyboardInterrupt`` from ``_arun`` is reported
            via ``on_tool_error`` and then re-raised unchanged.
        """
        if verbose is None:
            verbose = self.verbose
        await self._await_if_async(
            self.callback_manager.on_tool_start(
                {"name": self.name, "description": self.description},
                tool_input,
                verbose=verbose,
                color=start_color,
                **kwargs,
            )
        )
        try:
            # We then call the tool on the tool input to get an observation
            observation = await self._arun(tool_input)
        except (Exception, KeyboardInterrupt) as e:
            # NOTE(review): on Python 3.8+ asyncio.CancelledError derives from
            # BaseException, so cancellation is not reported here — intended?
            await self._await_if_async(
                self.callback_manager.on_tool_error(e, verbose=verbose)
            )
            raise
        await self._await_if_async(
            self.callback_manager.on_tool_end(
                observation, verbose=verbose, color=color, name=self.name, **kwargs
            )
        )
        return observation