Accept str or list[str] for shell (#4060)

Relax the requirements
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 5a269d3175
commit 65c3b146c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
import asyncio import asyncio
import platform import platform
import warnings import warnings
from typing import List, Optional, Type from typing import List, Optional, Type, Union
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, root_validator
@ -16,7 +16,7 @@ from langchain.utilities.bash import BashProcess
class ShellInput(BaseModel): class ShellInput(BaseModel):
"""Commands for the Bash Shell tool.""" """Commands for the Bash Shell tool."""
commands: List[str] = Field( commands: Union[str, List[str]] = Field(
..., ...,
description="List of shell commands to run. Deserialized using json.loads", description="List of shell commands to run. Deserialized using json.loads",
) )
@ -66,7 +66,7 @@ class ShellTool(BaseTool):
def _run( def _run(
self, self,
commands: List[str], commands: Union[str, List[str]],
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str: ) -> str:
"""Run commands and return final output.""" """Run commands and return final output."""
@ -74,7 +74,7 @@ class ShellTool(BaseTool):
async def _arun( async def _arun(
self, self,
commands: List[str], commands: Union[str, List[str]],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> str:
"""Run commands asynchronously and return final output.""" """Run commands asynchronously and return final output."""

@ -30,6 +30,12 @@ def test_shell_tool_init() -> None:
assert shell_tool.process is not None assert shell_tool.process is not None
def test_shell_tool_run() -> None:
shell_tool = ShellTool()
result = shell_tool._run(commands=test_commands)
assert result.strip() == "Hello, World!\nAnother command"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_shell_tool_arun() -> None: async def test_shell_tool_arun() -> None:
shell_tool = ShellTool() shell_tool = ShellTool()
@ -37,7 +43,7 @@ async def test_shell_tool_arun() -> None:
assert result.strip() == "Hello, World!\nAnother command" assert result.strip() == "Hello, World!\nAnother command"
def test_shell_tool_run() -> None: def test_shell_tool_run_str() -> None:
shell_tool = ShellTool() shell_tool = ShellTool()
result = shell_tool._run(commands=test_commands) result = shell_tool._run(commands="echo 'Hello, World!'")
assert result.strip() == "Hello, World!\nAnother command" assert result.strip() == "Hello, World!"

Loading…
Cancel
Save