diff --git a/docs/modules/agents/tools/examples/filesystem.ipynb b/docs/modules/agents/tools/examples/filesystem.ipynb new file mode 100644 index 00000000..61815baa --- /dev/null +++ b/docs/modules/agents/tools/examples/filesystem.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# File System Tools\n", + "\n", + "LangChain provides tools for interacting with a local file system out of the box. This notebook walks through some of them.\n", + "\n", + "Note: these tools are not recommended for use outside a sandboxed environment! " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we'll import the tools." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.tools.file_management import (\n", + " ReadFileTool,\n", + " CopyFileTool,\n", + " DeleteFileTool,\n", + " MoveFileTool,\n", + " WriteFileTool,\n", + " ListDirectoryTool,\n", + ")\n", + "from langchain.agents.agent_toolkits import FileManagementToolkit\n", + "from tempfile import TemporaryDirectory\n", + "\n", + "# We'll make a temporary directory to avoid clutter\n", + "working_directory = TemporaryDirectory()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The FileManagementToolkit\n", + "\n", + "If you want to provide all the file tooling to your agent, it's easy to do so with the toolkit. We'll pass the temporary directory in as a root directory as a workspace for the LLM.\n", + "\n", + "It's recommended to always pass in a root directory, since without one, it's easy for the LLM to pollute the working directory, and without one, there isn't any validation against\n", + "straightforward prompt injection." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[CopyFileTool(name='copy_file', description='Create a copy of a file in a specified location', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " DeleteFileTool(name='file_delete', description='Delete a file', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " FileSearchTool(name='file_search', description='Recursively search for files in a subdirectory that match the regex pattern', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " MoveFileTool(name='move_file', description='Move or rename a file from one location to another', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " ReadFileTool(name='read_file', description='Read file from disk', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " WriteFileTool(name='write_file', description='Write file to disk', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " ListDirectoryTool(name='list_directory', description='List files and directories in a specified folder', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug')]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "toolkit = FileManagementToolkit(root_dir=str(working_directory.name)) # 
If you don't provide a root_dir, operations will default to the current working directory\n", + "toolkit.get_tools()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Selecting File System Tools\n", + "\n", + "If you only want to select certain tools, you can pass them in as arguments when initializing the toolkit, or you can individually initialize the desired tools." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[ReadFileTool(name='read_file', description='Read file from disk', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " WriteFileTool(name='write_file', description='Write file to disk', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n", + " ListDirectoryTool(name='list_directory', description='List files and directories in a specified folder', args_schema=, return_direct=False, verbose=False, callback_manager=, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug')]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tools = FileManagementToolkit(root_dir=str(working_directory.name), selected_tools=[\"read_file\", \"write_file\", \"list_directory\"]).get_tools()\n", + "tools" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'File written successfully to example.txt.'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "read_tool, write_tool, list_tool = tools\n", + "write_tool.run({\"file_path\": \"example.txt\", \"text\": \"Hello World!\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/plain": [ + "'example.txt'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# List files in the working directory\n", + "list_tool.run({})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/langchain/agents/agent_toolkits/__init__.py b/langchain/agents/agent_toolkits/__init__.py index 54f1a2e6..4a0bbd1c 100644 --- a/langchain/agents/agent_toolkits/__init__.py +++ b/langchain/agents/agent_toolkits/__init__.py @@ -1,6 +1,9 @@ """Agent toolkits.""" from langchain.agents.agent_toolkits.csv.base import create_csv_agent +from langchain.agents.agent_toolkits.file_management.toolkit import ( + FileManagementToolkit, +) from langchain.agents.agent_toolkits.jira.toolkit import JiraToolkit from langchain.agents.agent_toolkits.json.base import create_json_agent from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit @@ -47,5 +50,6 @@ __all__ = [ "create_csv_agent", "ZapierToolkit", "JiraToolkit", + "FileManagementToolkit", "PlayWrightBrowserToolkit", ] diff --git a/langchain/agents/agent_toolkits/file_management/__init__.py b/langchain/agents/agent_toolkits/file_management/__init__.py new file mode 100644 index 00000000..d245816a --- /dev/null +++ b/langchain/agents/agent_toolkits/file_management/__init__.py @@ -0,0 +1,7 @@ +"""Local file management toolkit.""" + +from langchain.agents.agent_toolkits.file_management.toolkit import ( + FileManagementToolkit, +) + 
+__all__ = ["FileManagementToolkit"] diff --git a/langchain/agents/agent_toolkits/file_management/toolkit.py b/langchain/agents/agent_toolkits/file_management/toolkit.py new file mode 100644 index 00000000..17ae4f3a --- /dev/null +++ b/langchain/agents/agent_toolkits/file_management/toolkit.py @@ -0,0 +1,61 @@ +"""Toolkit for interacting with the local filesystem.""" +from __future__ import annotations + +from typing import List, Optional + +from pydantic import root_validator + +from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.tools import BaseTool +from langchain.tools.file_management.copy import CopyFileTool +from langchain.tools.file_management.delete import DeleteFileTool +from langchain.tools.file_management.file_search import FileSearchTool +from langchain.tools.file_management.list_dir import ListDirectoryTool +from langchain.tools.file_management.move import MoveFileTool +from langchain.tools.file_management.read import ReadFileTool +from langchain.tools.file_management.write import WriteFileTool + +_FILE_TOOLS = { + tool_cls.__fields__["name"].default: tool_cls + for tool_cls in [ + CopyFileTool, + DeleteFileTool, + FileSearchTool, + MoveFileTool, + ReadFileTool, + WriteFileTool, + ListDirectoryTool, + ] +} + + +class FileManagementToolkit(BaseToolkit): + """Toolkit for interacting with local files.""" + + root_dir: Optional[str] = None + """If specified, all file operations are made relative to root_dir.""" + selected_tools: Optional[List[str]] = None + """If provided, only provide the selected tools. Defaults to all.""" + + @root_validator + def validate_tools(cls, values: dict) -> dict: + selected_tools = values.get("selected_tools") or [] + for tool_name in selected_tools: + if tool_name not in _FILE_TOOLS: + raise ValueError( + f"File Tool of name {tool_name} not supported."
+ f" Permitted tools: {list(_FILE_TOOLS)}" + ) + return values + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + allowed_tools = self.selected_tools or _FILE_TOOLS.keys() + tools: List[BaseTool] = [] + for tool in allowed_tools: + tool_cls = _FILE_TOOLS[tool] + tools.append(tool_cls(root_dir=self.root_dir)) + return tools + + +__all__ = ["FileManagementToolkit"] diff --git a/langchain/tools/__init__.py b/langchain/tools/__init__.py index 3cf1dad2..a95590cd 100644 --- a/langchain/tools/__init__.py +++ b/langchain/tools/__init__.py @@ -3,6 +3,13 @@ from langchain.tools.base import BaseTool from langchain.tools.bing_search.tool import BingSearchResults, BingSearchRun from langchain.tools.ddg_search.tool import DuckDuckGoSearchResults, DuckDuckGoSearchRun +from langchain.tools.file_management.copy import CopyFileTool +from langchain.tools.file_management.delete import DeleteFileTool +from langchain.tools.file_management.file_search import FileSearchTool +from langchain.tools.file_management.list_dir import ListDirectoryTool +from langchain.tools.file_management.move import MoveFileTool +from langchain.tools.file_management.read import ReadFileTool +from langchain.tools.file_management.write import WriteFileTool from langchain.tools.google_places.tool import GooglePlacesTool from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun from langchain.tools.ifttt import IFTTTWebhook @@ -21,25 +28,33 @@ from langchain.tools.playwright import ( from langchain.tools.plugin import AIPluginTool __all__ = [ - "APIOperation", + "AIPluginTool", "BaseBrowserTool", "BaseTool", + "BaseTool", "BingSearchResults", "BingSearchRun", "ClickTool", + "CopyFileTool", "CurrentWebPageTool", + "DeleteFileTool", "DuckDuckGoSearchResults", "DuckDuckGoSearchRun", "DuckDuckGoSearchRun", "ExtractHyperlinksTool", "ExtractTextTool", + "FileSearchTool", "GetElementsTool", "GooglePlacesTool", "GoogleSearchResults", "GoogleSearchRun", 
"IFTTTWebhook", + "ListDirectoryTool", + "MoveFileTool", "NavigateBackTool", "NavigateTool", "OpenAPISpec", - "AIPluginTool", + "ReadFileTool", + "WriteFileTool", + "APIOperation", ] diff --git a/langchain/tools/file_management/__init__.py b/langchain/tools/file_management/__init__.py index e69de29b..4b2b6c0d 100644 --- a/langchain/tools/file_management/__init__.py +++ b/langchain/tools/file_management/__init__.py @@ -0,0 +1,19 @@ +"""File Management Tools.""" + +from langchain.tools.file_management.copy import CopyFileTool +from langchain.tools.file_management.delete import DeleteFileTool +from langchain.tools.file_management.file_search import FileSearchTool +from langchain.tools.file_management.list_dir import ListDirectoryTool +from langchain.tools.file_management.move import MoveFileTool +from langchain.tools.file_management.read import ReadFileTool +from langchain.tools.file_management.write import WriteFileTool + +__all__ = [ + "CopyFileTool", + "DeleteFileTool", + "FileSearchTool", + "MoveFileTool", + "ReadFileTool", + "WriteFileTool", + "ListDirectoryTool", +] diff --git a/langchain/tools/file_management/copy.py b/langchain/tools/file_management/copy.py new file mode 100644 index 00000000..0fb23c7e --- /dev/null +++ b/langchain/tools/file_management/copy.py @@ -0,0 +1,46 @@ +import shutil +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) + + +class FileCopyInput(BaseModel): + """Input for CopyFileTool.""" + + source_path: str = Field(..., description="Path of the file to copy") + destination_path: str = Field(..., description="Path to save the copied file") + + +class CopyFileTool(BaseFileTool): + name: str = "copy_file" + args_schema: Type[BaseModel] = FileCopyInput + description: str = "Create a copy of a file in a specified location" + + def _run(self, source_path: str, destination_path: str) -> str: + try: + 
source_path_ = self.get_relative_path(source_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format( + arg_name="source_path", value=source_path + ) + try: + destination_path_ = self.get_relative_path(destination_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format( + arg_name="destination_path", value=destination_path + ) + try: + shutil.copy2(source_path_, destination_path_, follow_symlinks=False) + return f"File copied successfully from {source_path} to {destination_path}." + except Exception as e: + return "Error: " + str(e) + + async def _arun(self, source_path: str, destination_path: str) -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/langchain/tools/file_management/delete.py b/langchain/tools/file_management/delete.py new file mode 100644 index 00000000..218cf606 --- /dev/null +++ b/langchain/tools/file_management/delete.py @@ -0,0 +1,39 @@ +import os +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) + + +class FileDeleteInput(BaseModel): + """Input for DeleteFileTool.""" + + file_path: str = Field(..., description="Path of the file to delete") + + +class DeleteFileTool(BaseFileTool): + name: str = "file_delete" + args_schema: Type[BaseModel] = FileDeleteInput + description: str = "Delete a file" + + def _run(self, file_path: str) -> str: + try: + file_path_ = self.get_relative_path(file_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format(arg_name="file_path", value=file_path) + if not file_path_.exists(): + return f"Error: no such file or directory: {file_path}" + try: + os.remove(file_path_) + return f"File deleted successfully: {file_path}." 
+ except Exception as e: + return "Error: " + str(e) + + async def _arun(self, file_path: str) -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/langchain/tools/file_management/file_search.py b/langchain/tools/file_management/file_search.py new file mode 100644 index 00000000..7e2f1d93 --- /dev/null +++ b/langchain/tools/file_management/file_search.py @@ -0,0 +1,55 @@ +import fnmatch +import os +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) + + +class FileSearchInput(BaseModel): + """Input for FileSearchTool.""" + + dir_path: str = Field( + default=".", + description="Subdirectory to search in.", + ) + pattern: str = Field( + ..., + description="Unix shell regex, where * matches everything.", + ) + + +class FileSearchTool(BaseFileTool): + name: str = "file_search" + args_schema: Type[BaseModel] = FileSearchInput + description: str = ( + "Recursively search for files in a subdirectory that match the regex pattern" + ) + + def _run(self, pattern: str, dir_path: str = ".") -> str: + try: + dir_path_ = self.get_relative_path(dir_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value=dir_path) + matches = [] + try: + for root, _, filenames in os.walk(dir_path_): + for filename in fnmatch.filter(filenames, pattern): + absolute_path = os.path.join(root, filename) + relative_path = os.path.relpath(absolute_path, dir_path_) + matches.append(relative_path) + if matches: + return "\n".join(matches) + else: + return f"No files found for pattern {pattern} in directory {dir_path}" + except Exception as e: + return "Error: " + str(e) + + async def _arun(self, dir_path: str, pattern: str) -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/langchain/tools/file_management/list_dir.py b/langchain/tools/file_management/list_dir.py new file mode 
100644 index 00000000..ff5cb8a1 --- /dev/null +++ b/langchain/tools/file_management/list_dir.py @@ -0,0 +1,40 @@ +import os +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) + + +class DirectoryListingInput(BaseModel): + """Input for ListDirectoryTool.""" + + dir_path: str = Field(default=".", description="Subdirectory to list.") + + +class ListDirectoryTool(BaseFileTool): + name: str = "list_directory" + args_schema: Type[BaseModel] = DirectoryListingInput + description: str = "List files and directories in a specified folder" + + def _run(self, dir_path: str = ".") -> str: + try: + dir_path_ = self.get_relative_path(dir_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value=dir_path) + try: + entries = os.listdir(dir_path_) + if entries: + return "\n".join(entries) + else: + return f"No files found in directory {dir_path}" + except Exception as e: + return "Error: " + str(e) + + async def _arun(self, dir_path: str = ".") -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/langchain/tools/file_management/move.py b/langchain/tools/file_management/move.py new file mode 100644 index 00000000..ccf88796 --- /dev/null +++ b/langchain/tools/file_management/move.py @@ -0,0 +1,49 @@ +import shutil +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) + + +class FileMoveInput(BaseModel): + """Input for MoveFileTool.""" + + source_path: str = Field(..., description="Path of the file to move") + destination_path: str = Field(..., description="New path for the moved file") + + +class MoveFileTool(BaseFileTool): + name: str = "move_file" + args_schema: Type[BaseModel] = FileMoveInput + description: str = "Move or rename a file from one location
to another" + + def _run(self, source_path: str, destination_path: str) -> str: + try: + source_path_ = self.get_relative_path(source_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format( + arg_name="source_path", value=source_path + ) + try: + destination_path_ = self.get_relative_path(destination_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format( + arg_name="destination_path", value=destination_path + ) + if not source_path_.exists(): + return f"Error: no such file or directory {source_path}" + try: + # shutil.move expects str args in 3.8 + shutil.move(str(source_path_), destination_path_) + return f"File moved successfully from {source_path} to {destination_path}." + except Exception as e: + return "Error: " + str(e) + + async def _arun(self, source_path: str, destination_path: str) -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/langchain/tools/file_management/read.py b/langchain/tools/file_management/read.py index 13d7dcb0..d243a9e3 100644 --- a/langchain/tools/file_management/read.py +++ b/langchain/tools/file_management/read.py @@ -1,10 +1,12 @@ -from pathlib import Path -from typing import Optional, Type +from typing import Type from pydantic import BaseModel, Field -from langchain.tools.base import BaseTool -from langchain.tools.file_management.utils import get_validated_relative_path +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) class ReadFileInput(BaseModel): @@ -13,21 +15,18 @@ class ReadFileInput(BaseModel): file_path: str = Field(..., description="name of file") -class ReadFileTool(BaseTool): +class ReadFileTool(BaseFileTool): name: str = "read_file" args_schema: Type[BaseModel] = ReadFileInput description: str = "Read file from disk" - root_dir: Optional[str] = None - """Directory to read file from.
- - If specified, raises an error for file_paths oustide root_dir.""" def _run(self, file_path: str) -> str: - read_path = ( - get_validated_relative_path(Path(self.root_dir), file_path) - if self.root_dir - else Path(file_path) - ) + try: + read_path = self.get_relative_path(file_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format(arg_name="file_path", value=file_path) + if not read_path.exists(): + return f"Error: no such file or directory: {file_path}" try: with read_path.open("r", encoding="utf-8") as f: content = f.read() @@ -35,6 +34,6 @@ class ReadFileTool(BaseTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, tool_input: str) -> str: + async def _arun(self, file_path: str) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/utils.py b/langchain/tools/file_management/utils.py index 6d3bb4c1..c8efefb4 100644 --- a/langchain/tools/file_management/utils.py +++ b/langchain/tools/file_management/utils.py @@ -1,5 +1,10 @@ import sys from pathlib import Path +from typing import Any, Optional + +from pydantic import BaseModel + +from langchain.tools.base import BaseTool def is_relative_to(path: Path, root: Path) -> bool: @@ -14,6 +19,35 @@ def is_relative_to(path: Path, root: Path) -> bool: return False +INVALID_PATH_TEMPLATE = ( + "Error: Access denied to {arg_name}: {value}." 
+ " Permission granted exclusively to the current working directory" +) + + +class FileValidationError(ValueError): + """Error for paths outside the root directory.""" + + +class BaseFileTool(BaseTool, BaseModel): + """Base class for file tools with an optional root_dir.""" + + root_dir: Optional[str] = None + """The final path will be chosen relative to root_dir if specified.""" + + def get_relative_path(self, file_path: str) -> Path: + """Get the relative path, raising FileValidationError if outside root_dir.""" + if self.root_dir is None: + return Path(file_path) + return get_validated_relative_path(Path(self.root_dir), file_path) + + def _run(self, *args: Any, **kwargs: Any) -> str: + raise NotImplementedError + + async def _arun(self, *args: Any, **kwargs: Any) -> str: + raise NotImplementedError + + def get_validated_relative_path(root: Path, user_path: str) -> Path: """Resolve a relative path, raising an error if not within the root directory.""" # Note, this still permits symlinks from outside that point within the root.
@@ -22,5 +56,7 @@ def get_validated_relative_path(root: Path, user_path: str) -> Path: full_path = (root / user_path).resolve() if not is_relative_to(full_path, root): - raise ValueError(f"Path {user_path} is outside of the allowed directory {root}") + raise FileValidationError( + f"Path {user_path} is outside of the allowed directory {root}" + ) return full_path diff --git a/langchain/tools/file_management/write.py b/langchain/tools/file_management/write.py index ae4093b2..865bcbe7 100644 --- a/langchain/tools/file_management/write.py +++ b/langchain/tools/file_management/write.py @@ -1,10 +1,12 @@ -from pathlib import Path -from typing import Optional, Type +from typing import Type from pydantic import BaseModel, Field -from langchain.tools.base import BaseTool -from langchain.tools.file_management.utils import get_validated_relative_path +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, + BaseFileTool, + FileValidationError, +) class WriteFileInput(BaseModel): @@ -12,31 +14,30 @@ class WriteFileInput(BaseModel): file_path: str = Field(..., description="name of file") text: str = Field(..., description="text to write to file") + append: bool = Field( + default=False, description="Whether to append to an existing file." + ) -class WriteFileTool(BaseTool): +class WriteFileTool(BaseFileTool): name: str = "write_file" args_schema: Type[BaseModel] = WriteFileInput description: str = "Write file to disk" - root_dir: Optional[str] = None - """Directory to write file to. 
- If specified, raises an error for file_paths oustide root_dir.""" - - def _run(self, file_path: str, text: str) -> str: - write_path = ( - get_validated_relative_path(Path(self.root_dir), file_path) - if self.root_dir - else Path(file_path) - ) + def _run(self, file_path: str, text: str, append: bool = False) -> str: + try: + write_path = self.get_relative_path(file_path) + except FileValidationError: + return INVALID_PATH_TEMPLATE.format(arg_name="file_path", value=file_path) try: write_path.parent.mkdir(exist_ok=True, parents=False) - with write_path.open("w", encoding="utf-8") as f: + mode = "a" if append else "w" + with write_path.open(mode, encoding="utf-8") as f: f.write(text) return f"File written successfully to {file_path}." except Exception as e: return "Error: " + str(e) - async def _arun(self, file_path: str, text: str) -> str: + async def _arun(self, file_path: str, text: str, append: bool = False) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/tests/unit_tests/tools/file_management/test_copy.py b/tests/unit_tests/tools/file_management/test_copy.py new file mode 100644 index 00000000..591e06c1 --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_copy.py @@ -0,0 +1,54 @@ +"""Test the FileCopy tool.""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +from langchain.tools.file_management.copy import CopyFileTool +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, +) + + +def test_copy_file_with_root_dir() -> None: + """Test the FileCopy tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = CopyFileTool(root_dir=temp_dir) + source_file = Path(temp_dir) / "source.txt" + destination_file = Path(temp_dir) / "destination.txt" + source_file.write_text("Hello, world!") + tool.run({"source_path": "source.txt", "destination_path": "destination.txt"}) + assert source_file.exists() + assert destination_file.exists() + assert 
source_file.read_text() == "Hello, world!" + assert destination_file.read_text() == "Hello, world!" + + +def test_copy_file_errs_outside_root_dir() -> None: + """Test the FileCopy tool rejects paths outside the root dir.""" + with TemporaryDirectory() as temp_dir: + tool = CopyFileTool(root_dir=temp_dir) + result = tool.run( + { + "source_path": "../source.txt", + "destination_path": "../destination.txt", + } + ) + assert result == INVALID_PATH_TEMPLATE.format( + arg_name="source_path", value="../source.txt" + ) + + +def test_copy_file() -> None: + """Test the FileCopy tool.""" + with TemporaryDirectory() as temp_dir: + tool = CopyFileTool() + source_file = Path(temp_dir) / "source.txt" + destination_file = Path(temp_dir) / "destination.txt" + source_file.write_text("Hello, world!") + tool.run( + {"source_path": str(source_file), "destination_path": str(destination_file)} + ) + assert source_file.exists() + assert destination_file.exists() + assert source_file.read_text() == "Hello, world!" + assert destination_file.read_text() == "Hello, world!"
diff --git a/tests/unit_tests/tools/file_management/test_file_search.py b/tests/unit_tests/tools/file_management/test_file_search.py new file mode 100644 index 00000000..bb10b5cc --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_file_search.py @@ -0,0 +1,43 @@ +"""Test the FileSearch tool.""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +from langchain.tools.file_management.file_search import FileSearchTool +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, +) + + +def test_file_search_with_root_dir() -> None: + """Test the FileSearch tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = FileSearchTool(root_dir=temp_dir) + file_1 = Path(temp_dir) / "file1.txt" + file_2 = Path(temp_dir) / "file2.log" + file_1.write_text("File 1 content") + file_2.write_text("File 2 content") + matches = tool.run({"dir_path": ".", "pattern": "*.txt"}).split("\n") + assert len(matches) == 1 + assert Path(matches[0]).name == "file1.txt" + + +def test_file_search_errs_outside_root_dir() -> None: + """Test the FileSearch tool rejects paths outside the root dir.""" + with TemporaryDirectory() as temp_dir: + tool = FileSearchTool(root_dir=temp_dir) + result = tool.run({"dir_path": "..", "pattern": "*.txt"}) + assert result == INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value="..") + + +def test_file_search() -> None: + """Test the FileSearch tool.""" + with TemporaryDirectory() as temp_dir: + tool = FileSearchTool() + file_1 = Path(temp_dir) / "file1.txt" + file_2 = Path(temp_dir) / "file2.log" + file_1.write_text("File 1 content") + file_2.write_text("File 2 content") + matches = tool.run({"dir_path": temp_dir, "pattern": "*.txt"}).split("\n") + assert len(matches) == 1 + assert Path(matches[0]).name == "file1.txt" diff --git a/tests/unit_tests/tools/file_management/test_list_dir.py b/tests/unit_tests/tools/file_management/test_list_dir.py new file mode 100644 index 00000000..757973cc ---
/dev/null +++ b/tests/unit_tests/tools/file_management/test_list_dir.py @@ -0,0 +1,41 @@ +"""Test the DirectoryListing tool.""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +from langchain.tools.file_management.list_dir import ListDirectoryTool +from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, +) + + +def test_list_directory_with_root_dir() -> None: + """Test the DirectoryListing tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = ListDirectoryTool(root_dir=temp_dir) + file_1 = Path(temp_dir) / "file1.txt" + file_2 = Path(temp_dir) / "file2.txt" + file_1.write_text("File 1 content") + file_2.write_text("File 2 content") + entries = tool.run({"dir_path": "."}).split("\n") + assert set(entries) == {"file1.txt", "file2.txt"} + + +def test_list_directory_errs_outside_root_dir() -> None: + """Test the DirectoryListing tool rejects paths outside the root dir.""" + with TemporaryDirectory() as temp_dir: + tool = ListDirectoryTool(root_dir=temp_dir) + result = tool.run({"dir_path": ".."}) + assert result == INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value="..") + + +def test_list_directory() -> None: + """Test the DirectoryListing tool.""" + with TemporaryDirectory() as temp_dir: + tool = ListDirectoryTool() + file_1 = Path(temp_dir) / "file1.txt" + file_2 = Path(temp_dir) / "file2.txt" + file_1.write_text("File 1 content") + file_2.write_text("File 2 content") + entries = tool.run({"dir_path": temp_dir}).split("\n") + assert set(entries) == {"file1.txt", "file2.txt"} diff --git a/tests/unit_tests/tools/file_management/test_move.py b/tests/unit_tests/tools/file_management/test_move.py new file mode 100644 index 00000000..ec3833af --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_move.py @@ -0,0 +1,52 @@ +"""Test the FileMove tool.""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +from langchain.tools.file_management.move import MoveFileTool
+from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, +) + + +def test_move_file_with_root_dir() -> None: + """Test the FileMove tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = MoveFileTool(root_dir=temp_dir) + source_file = Path(temp_dir) / "source.txt" + destination_file = Path(temp_dir) / "destination.txt" + source_file.write_text("Hello, world!") + tool.run({"source_path": "source.txt", "destination_path": "destination.txt"}) + assert not source_file.exists() + assert destination_file.exists() + assert destination_file.read_text() == "Hello, world!" + + +def test_move_file_errs_outside_root_dir() -> None: + """Test the FileMove tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = MoveFileTool(root_dir=temp_dir) + result = tool.run( + { + "source_path": "../source.txt", + "destination_path": "../destination.txt", + } + ) + assert result == INVALID_PATH_TEMPLATE.format( + arg_name="source_path", value="../source.txt" + ) + + +def test_move_file() -> None: + """Test the FileMove tool.""" + with TemporaryDirectory() as temp_dir: + tool = MoveFileTool() + source_file = Path(temp_dir) / "source.txt" + destination_file = Path(temp_dir) / "destination.txt" + source_file.write_text("Hello, world!") + tool.run( + {"source_path": str(source_file), "destination_path": str(destination_file)} + ) + assert not source_file.exists() + assert destination_file.exists() + assert destination_file.read_text() == "Hello, world!" 
diff --git a/tests/unit_tests/tools/file_management/test_toolkit.py b/tests/unit_tests/tools/file_management/test_toolkit.py new file mode 100644 index 00000000..34d71d07 --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_toolkit.py @@ -0,0 +1,48 @@ +"""Test the FileManagementToolkit.""" + +from tempfile import TemporaryDirectory + +import pytest + +from langchain.agents.agent_toolkits.file_management.toolkit import ( + FileManagementToolkit, +) +from langchain.tools.base import BaseTool + + +def test_file_toolkit_get_tools() -> None: + """Test the get_tools method of FileManagementToolkit.""" + with TemporaryDirectory() as temp_dir: + toolkit = FileManagementToolkit(root_dir=temp_dir) + tools = toolkit.get_tools() + assert len(tools) > 0 + assert all(isinstance(tool, BaseTool) for tool in tools) + + +def test_file_toolkit_get_tools_with_selection() -> None: + """Test the get_tools method of FileManagementToolkit with selected_tools.""" + with TemporaryDirectory() as temp_dir: + toolkit = FileManagementToolkit( + root_dir=temp_dir, selected_tools=["read_file", "write_file"] + ) + tools = toolkit.get_tools() + assert len(tools) == 2 + tool_names = [tool.name for tool in tools] + assert "read_file" in tool_names + assert "write_file" in tool_names + + +def test_file_toolkit_invalid_tool() -> None: + """Test the FileManagementToolkit with an invalid tool.""" + with TemporaryDirectory() as temp_dir: + with pytest.raises(ValueError): + FileManagementToolkit(root_dir=temp_dir, selected_tools=["invalid_tool"]) + + +def test_file_toolkit_root_dir() -> None: + """Test the FileManagementToolkit root_dir handling.""" + with TemporaryDirectory() as temp_dir: + toolkit = FileManagementToolkit(root_dir=temp_dir) + tools = toolkit.get_tools() + root_dirs = [tool.root_dir for tool in tools if hasattr(tool, "root_dir")] + assert all(root_dir == temp_dir for root_dir in root_dirs) diff --git a/tests/unit_tests/tools/file_management/test_utils.py 
b/tests/unit_tests/tools/file_management/test_utils.py index 31a6d010..4ad447bf 100644 --- a/tests/unit_tests/tools/file_management/test_utils.py +++ b/tests/unit_tests/tools/file_management/test_utils.py @@ -6,7 +6,10 @@ from tempfile import TemporaryDirectory import pytest -from langchain.tools.file_management.utils import get_validated_relative_path +from langchain.tools.file_management.utils import ( + FileValidationError, + get_validated_relative_path, +) def test_get_validated_relative_path_errs_on_absolute() -> None: @@ -14,7 +17,7 @@ def test_get_validated_relative_path_errs_on_absolute() -> None: root = Path(__file__).parent user_path = "/bin/bash" matches = f"Path {user_path} is outside of the allowed directory {root}" - with pytest.raises(ValueError, match=matches): + with pytest.raises(FileValidationError, match=matches): get_validated_relative_path(root, user_path) @@ -23,7 +26,7 @@ def test_get_validated_relative_path_errs_on_parent_dir() -> None: root = Path(__file__).parent user_path = "data/sub/../../../sibling" matches = f"Path {user_path} is outside of the allowed directory {root}" - with pytest.raises(ValueError, match=matches): + with pytest.raises(FileValidationError, match=matches): get_validated_relative_path(root, user_path) @@ -49,7 +52,7 @@ def test_get_validated_relative_path_errs_for_symlink_outside_root() -> None: matches = ( f"Path {user_path} is outside of the allowed directory {root.resolve()}" ) - with pytest.raises(ValueError, match=matches): + with pytest.raises(FileValidationError, match=matches): get_validated_relative_path(root, user_path) symlink_path.unlink() diff --git a/tests/unit_tests/tools/file_management/test_write.py b/tests/unit_tests/tools/file_management/test_write.py index f6222ef8..cbb81d8e 100644 --- a/tests/unit_tests/tools/file_management/test_write.py +++ b/tests/unit_tests/tools/file_management/test_write.py @@ -3,8 +3,9 @@ from pathlib import Path from tempfile import TemporaryDirectory -import pytest - 
+from langchain.tools.file_management.utils import ( + INVALID_PATH_TEMPLATE, +) from langchain.tools.file_management.write import WriteFileTool @@ -21,8 +22,10 @@ def test_write_file_errs_outside_root_dir() -> None: """Test the WriteFile tool when a root dir is specified.""" with TemporaryDirectory() as temp_dir: tool = WriteFileTool(root_dir=temp_dir) - with pytest.raises(ValueError): - tool.run({"file_path": "../file.txt", "text": "Hello, world!"}) + result = tool.run({"file_path": "../file.txt", "text": "Hello, world!"}) + assert result == INVALID_PATH_TEMPLATE.format( + arg_name="file_path", value="../file.txt" + ) def test_write_file() -> None: