mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
90 lines
2.6 KiB
Python
90 lines
2.6 KiB
Python
|
import asyncio
|
||
|
import platform
|
||
|
import warnings
|
||
|
from typing import Any, List, Optional, Type, Union
|
||
|
|
||
|
from langchain_core.callbacks import (
|
||
|
AsyncCallbackManagerForToolRun,
|
||
|
CallbackManagerForToolRun,
|
||
|
)
|
||
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||
|
from langchain_core.tools import BaseTool
|
||
|
|
||
|
|
||
|
class ShellInput(BaseModel):
    """Commands for the Bash Shell tool."""

    commands: Union[str, List[str]] = Field(
        ...,
        description="List of shell commands to run. Deserialized using json.loads",
    )
    """List of shell commands to run."""

    # skip_on_failure=True: when ``commands`` fails field validation it is not
    # present in ``values``; without this flag the validator below would still
    # run, turn the missing value into ``[None]`` and emit a spurious warning.
    @root_validator(skip_on_failure=True)
    def _validate_commands(cls, values: dict) -> dict:
        """Normalize ``commands`` to a list and warn that the tool is unsafe.

        A single command string is wrapped in a one-element list, so the
        validated model always holds a list in ``commands``.

        Args:
            values: Field values that already passed field validation.

        Returns:
            ``values`` with ``commands`` guaranteed to be a list.
        """
        # TODO: Add real validators
        commands = values.get("commands")
        if not isinstance(commands, list):
            values["commands"] = [commands]
        # Warn that the bash tool is not safe
        warnings.warn(
            "The shell tool has no safeguards by default. Use at your own risk."
        )
        return values
|
||
|
|
||
|
|
||
|
def _get_default_bash_process() -> Any:
    """Create the default ``BashProcess`` used by ``ShellTool``.

    ``BashProcess`` lives in the optional ``langchain_experimental`` package,
    so it is imported lazily inside this function rather than at module load.

    Returns:
        A ``BashProcess`` instance constructed with ``return_err_output=True``.

    Raises:
        ImportError: If ``BashProcess`` cannot be imported from
            ``langchain_experimental``. The underlying import error is chained
            as ``__cause__`` so the real failure remains visible.
    """
    try:
        from langchain_experimental.llm_bash.bash import BashProcess
    except ImportError as e:
        # Note the trailing space after "experimental." — the literals below
        # are concatenated, and without it the sentences run together.
        raise ImportError(
            "BashProcess has been moved to langchain experimental. "
            "To use this tool, install langchain-experimental "
            "with `pip install langchain-experimental`."
        ) from e
    return BashProcess(return_err_output=True)
|
||
|
|
||
|
|
||
|
def _get_platform() -> str:
    """Return a display name for the operating system this code runs on.

    ``platform.system()`` reports macOS as ``"Darwin"``; that one value is
    translated to ``"MacOS"``. Any other value is passed through as-is.
    """
    display_names = {"Darwin": "MacOS"}
    raw_name = platform.system()
    return display_names.get(raw_name, raw_name)
|
||
|
|
||
|
|
||
|
class ShellTool(BaseTool):
    """Tool to run shell commands."""

    process: Any = Field(default_factory=_get_default_bash_process)
    """Bash process to run commands."""

    name: str = "terminal"
    """Name of tool."""

    description: str = f"Run shell commands on this {_get_platform()} machine."
    """Description of tool."""

    args_schema: Type[BaseModel] = ShellInput
    """Schema for input arguments."""

    def _run(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Run commands and return final output.

        Args:
            commands: A single command string or a list of command strings,
                forwarded unchanged to ``self.process.run``.
            run_manager: Callback manager for this tool run (unused here).

        Returns:
            The value returned by ``self.process.run(commands)``.
        """
        return self.process.run(commands)

    async def _arun(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Run commands asynchronously and return final output.

        ``self.process.run`` is a synchronous call, so it is dispatched to the
        running event loop's default executor to avoid blocking the loop.

        Args:
            commands: A single command string or a list of command strings,
                forwarded unchanged to ``self.process.run``.
            run_manager: Async callback manager for this tool run (unused here).

        Returns:
            The value returned by ``self.process.run(commands)``.
        """
        # Inside a coroutine, get_running_loop() is the documented way to get
        # the loop executing us; get_event_loop() here is a deprecated pattern.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.process.run, commands)
|