mirror of https://github.com/hwchase17/langchain
Add Shell Tool (#3335)
Create an official bash shell tool to replace the dynamically generated onepull/3728/head
parent
334c162f16
commit
5042bd40d3
@ -0,0 +1,5 @@
|
|||||||
|
"""Shell tool."""
|
||||||
|
|
||||||
|
from langchain.tools.shell.tool import ShellTool
|
||||||
|
|
||||||
|
__all__ = ["ShellTool"]
|
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
import platform
|
||||||
|
import warnings
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain.utilities.bash import BashProcess
|
||||||
|
|
||||||
|
|
||||||
|
class ShellInput(BaseModel):
|
||||||
|
"""Commands for the Bash Shell tool."""
|
||||||
|
|
||||||
|
commands: List[str] = Field(
|
||||||
|
...,
|
||||||
|
description="List of shell commands to run. Deserialized using json.loads",
|
||||||
|
)
|
||||||
|
"""List of shell commands to run."""
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def _validate_commands(cls, values: dict) -> dict:
|
||||||
|
"""Validate commands."""
|
||||||
|
# TODO: Add real validators
|
||||||
|
commands = values.get("commands")
|
||||||
|
if not isinstance(commands, list):
|
||||||
|
values["commands"] = [commands]
|
||||||
|
# Warn that the bash tool is not safe
|
||||||
|
warnings.warn(
|
||||||
|
"The shell tool has no safeguards by default. Use at your own risk."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_bash_processs() -> BashProcess:
|
||||||
|
"""Get file path from string."""
|
||||||
|
return BashProcess(return_err_output=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_platform() -> str:
|
||||||
|
"""Get platform."""
|
||||||
|
system = platform.system()
|
||||||
|
if system == "Darwin":
|
||||||
|
return "MacOS"
|
||||||
|
return system
|
||||||
|
|
||||||
|
|
||||||
|
class ShellTool(BaseTool):
|
||||||
|
"""Tool to run shell commands."""
|
||||||
|
|
||||||
|
process: BashProcess = Field(default_factory=_get_default_bash_processs)
|
||||||
|
"""Bash process to run commands."""
|
||||||
|
|
||||||
|
name: str = "terminal"
|
||||||
|
"""Name of tool."""
|
||||||
|
|
||||||
|
description: str = f"Run shell commands on this {_get_platform()} machine."
|
||||||
|
"""Description of tool."""
|
||||||
|
|
||||||
|
args_schema: Type[BaseModel] = ShellInput
|
||||||
|
"""Schema for input arguments."""
|
||||||
|
|
||||||
|
def _run(self, commands: List[str]) -> str:
|
||||||
|
"""Run commands and return final output."""
|
||||||
|
return self.process.run(commands)
|
||||||
|
|
||||||
|
async def _arun(self, commands: List[str]) -> str:
|
||||||
|
"""Run commands asynchronously and return final output."""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.process.run, commands
|
||||||
|
)
|
@ -0,0 +1,43 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.tools.shell.tool import ShellInput, ShellTool
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
test_commands = ["echo 'Hello, World!'", "echo 'Another command'"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_input_validation() -> None:
|
||||||
|
shell_input = ShellInput(commands=test_commands)
|
||||||
|
assert isinstance(shell_input.commands, list)
|
||||||
|
assert len(shell_input.commands) == 2
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
ShellInput(commands=test_commands)
|
||||||
|
assert len(w) == 1
|
||||||
|
assert (
|
||||||
|
str(w[-1].message)
|
||||||
|
== "The shell tool has no safeguards by default. Use at your own risk."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_tool_init() -> None:
|
||||||
|
shell_tool = ShellTool()
|
||||||
|
assert shell_tool.name == "terminal"
|
||||||
|
assert isinstance(shell_tool.description, str)
|
||||||
|
assert shell_tool.args_schema == ShellInput
|
||||||
|
assert shell_tool.process is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shell_tool_arun() -> None:
|
||||||
|
shell_tool = ShellTool()
|
||||||
|
result = await shell_tool._arun(commands=test_commands)
|
||||||
|
assert result.strip() == "Hello, World!\nAnother command"
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_tool_run() -> None:
|
||||||
|
shell_tool = ShellTool()
|
||||||
|
result = shell_tool._run(commands=test_commands)
|
||||||
|
assert result.strip() == "Hello, World!\nAnother command"
|
Loading…
Reference in New Issue