mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
f238217cea
- **Description:** Because the Shell tool is very versatile, developers who integrate it into applications as OpenAI functions have no way to see which command the ShellTool is executing. All one can see is: ![image](https://github.com/langchain-ai/langchain/assets/60742358/540e274a-debc-4564-9027-046b91424df3) Summarising my feature request: 1. There's no visibility into which command was executed. 2. There's no mechanism to prevent a command from being executed by the ShellTool, such as a y/n prompt that asks the user whether to proceed with executing the command. - **Issue:** fixes #15931, - **Dependencies:** There isn't any dependency, - **Twitter handle:** @krishnashed
103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
import logging
|
|
import platform
|
|
import warnings
|
|
from typing import Any, List, Optional, Type, Union
|
|
|
|
from langchain_core.callbacks import (
|
|
CallbackManagerForToolRun,
|
|
)
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
|
from langchain_core.tools import BaseTool
|
|
|
|
# Module-level logger; ShellTool reports aborted and failed runs through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ShellInput(BaseModel):
    """Input schema for the shell tool: one command or a list of commands."""

    # Accepts a bare string or a list of strings; the root validator below
    # wraps any non-list value in a list.
    commands: Union[str, List[str]] = Field(
        ...,
        description="List of shell commands to run. Deserialized using json.loads",
    )

    @root_validator
    def _validate_commands(cls, values: dict) -> dict:
        """Wrap a non-list ``commands`` value in a list and warn about safety.

        The warning is emitted every time this model is validated.
        """
        # TODO: Add real validators
        raw = values.get("commands")
        values["commands"] = raw if isinstance(raw, list) else [raw]

        # Nothing here restricts what can run, so remind the user of the risk.
        warnings.warn(
            "The shell tool has no safeguards by default. Use at your own risk."
        )
        return values
|
|
|
|
|
|
def _get_default_bash_process() -> Any:
|
|
"""Get default bash process."""
|
|
try:
|
|
from langchain_experimental.llm_bash.bash import BashProcess
|
|
except ImportError:
|
|
raise ImportError(
|
|
"BashProcess has been moved to langchain experimental."
|
|
"To use this tool, install langchain-experimental "
|
|
"with `pip install langchain-experimental`."
|
|
)
|
|
return BashProcess(return_err_output=True)
|
|
|
|
|
|
def _get_platform() -> str:
|
|
"""Get platform."""
|
|
system = platform.system()
|
|
if system == "Darwin":
|
|
return "MacOS"
|
|
return system
|
|
|
|
|
|
class ShellTool(BaseTool):
    """Tool to run shell commands."""

    process: Any = Field(default_factory=_get_default_bash_process)
    """Bash process to run commands."""

    name: str = "terminal"
    """Name of tool."""

    description: str = f"Run shell commands on this {_get_platform()} machine."
    """Description of tool."""

    args_schema: Type[BaseModel] = ShellInput
    """Schema for input arguments."""

    ask_human_input: bool = False
    """
    If True, prompts the user for confirmation (y/n) before executing
    a command generated by the language model in the bash shell.
    """

    def _confirm_execution(self) -> bool:
        """Ask the user on stdin whether the pending command may run.

        Returns:
            True only if the answer is "y" or "yes" (case-insensitive,
            surrounding whitespace ignored). Any other answer, or end of
            input (e.g. stdin is not interactive), counts as a refusal so
            nothing runs without explicit consent.
        """
        try:
            answer = input("Proceed with command execution? (y/n): ")
        except EOFError:
            return False
        return answer.strip().lower() in ("y", "yes")

    def _run(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Run commands and return final output.

        Args:
            commands: A single command string or a list of commands. It is
                passed unchanged to ``self.process.run``.
            run_manager: Callback manager for this tool run (not used here).

        Returns:
            Whatever ``self.process.run`` returns. In two cases a short
            explanatory string is returned instead, so the caller always
            gets a string rather than ``None``:
            - the user declines execution (only when ``ask_human_input``
              is True);
            - ``self.process.run`` raises.
        """
        # Always show what is about to run, so there is visibility into
        # commands generated by the language model.
        print(f"Executing command:\n {commands}")

        if self.ask_human_input and not self._confirm_execution():
            logger.info("User aborted command execution.")
            return "User aborted command execution."

        try:
            return self.process.run(commands)
        except Exception as e:
            # Best-effort: log with traceback and report the failure back to
            # the caller instead of propagating it.
            logger.exception("Error during command execution: %s", e)
            return f"Error during command execution: {e}"
|