From 65c3b146c971baafe1e3a67b16caddbf705b75d5 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Wed, 3 May 2023 21:11:06 -0700 Subject: [PATCH] Accept str or list[str] for shell (#4060) Relax the requirements --- langchain/tools/shell/tool.py | 8 ++++---- tests/unit_tests/tools/shell/test_shell.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/langchain/tools/shell/tool.py b/langchain/tools/shell/tool.py index 42e19038..28f2edcf 100644 --- a/langchain/tools/shell/tool.py +++ b/langchain/tools/shell/tool.py @@ -1,7 +1,7 @@ import asyncio import platform import warnings -from typing import List, Optional, Type +from typing import List, Optional, Type, Union from pydantic import BaseModel, Field, root_validator @@ -16,7 +16,7 @@ from langchain.utilities.bash import BashProcess class ShellInput(BaseModel): """Commands for the Bash Shell tool.""" - commands: List[str] = Field( + commands: Union[str, List[str]] = Field( ..., description="List of shell commands to run. Deserialized using json.loads", ) @@ -66,7 +66,7 @@ class ShellTool(BaseTool): def _run( self, - commands: List[str], + commands: Union[str, List[str]], run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Run commands and return final output.""" @@ -74,7 +74,7 @@ class ShellTool(BaseTool): async def _arun( self, - commands: List[str], + commands: Union[str, List[str]], run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Run commands asynchronously and return final output.""" diff --git a/tests/unit_tests/tools/shell/test_shell.py b/tests/unit_tests/tools/shell/test_shell.py index 48ab3f37..c4442030 100644 --- a/tests/unit_tests/tools/shell/test_shell.py +++ b/tests/unit_tests/tools/shell/test_shell.py @@ -30,6 +30,12 @@ def test_shell_tool_init() -> None: assert shell_tool.process is not None +def test_shell_tool_run() -> None: + shell_tool = ShellTool() + result = shell_tool._run(commands=test_commands) + assert result.strip() == "Hello, World!\nAnother command" + + @pytest.mark.asyncio async def test_shell_tool_arun() -> None: shell_tool = ShellTool() @@ -37,7 +43,7 @@ async def test_shell_tool_arun() -> None: assert result.strip() == "Hello, World!\nAnother command" -def test_shell_tool_run() -> None: +def test_shell_tool_run_str() -> None: shell_tool = ShellTool() - result = shell_tool._run(commands=test_commands) - assert result.strip() == "Hello, World!\nAnother command" + result = shell_tool._run(commands="echo 'Hello, World!'") + assert result.strip() == "Hello, World!"