diff --git a/langchain/tools/file_management/read.py b/langchain/tools/file_management/read.py index 229bd63eb4..13d7dcb09d 100644 --- a/langchain/tools/file_management/read.py +++ b/langchain/tools/file_management/read.py @@ -1,8 +1,10 @@ -from typing import Type +from pathlib import Path +from typing import Optional, Type from pydantic import BaseModel, Field from langchain.tools.base import BaseTool +from langchain.tools.file_management.utils import get_validated_relative_path class ReadFileInput(BaseModel): @@ -15,10 +17,19 @@ class ReadFileTool(BaseTool): name: str = "read_file" args_schema: Type[BaseModel] = ReadFileInput description: str = "Read file from disk" + root_dir: Optional[str] = None + """Directory to read file from. + + If specified, raises an error for file_paths outside root_dir.""" def _run(self, file_path: str) -> str: + read_path = ( + get_validated_relative_path(Path(self.root_dir), file_path) + if self.root_dir + else Path(file_path) + ) try: - with open(file_path, "r", encoding="utf-8") as f: + with read_path.open("r", encoding="utf-8") as f: content = f.read() return content except Exception as e: diff --git a/langchain/tools/file_management/utils.py b/langchain/tools/file_management/utils.py new file mode 100644 index 0000000000..6d3bb4c17d --- /dev/null +++ b/langchain/tools/file_management/utils.py @@ -0,0 +1,26 @@ +import sys +from pathlib import Path + + +def is_relative_to(path: Path, root: Path) -> bool: + """Check if path is relative to root.""" + if sys.version_info >= (3, 9): + # No need for a try/except block in Python 3.9+. + return path.is_relative_to(root) + try: + path.relative_to(root) + return True + except ValueError: + return False + + +def get_validated_relative_path(root: Path, user_path: str) -> Path: + """Resolve a relative path, raising an error if not within the root directory.""" + # Note, this still permits symlinks from outside that point within the root. 
+ # Further validation would be needed if those are to be disallowed. + root = root.resolve() + full_path = (root / user_path).resolve() + + if not is_relative_to(full_path, root): + raise ValueError(f"Path {user_path} is outside of the allowed directory {root}") + return full_path diff --git a/langchain/tools/file_management/write.py b/langchain/tools/file_management/write.py index 4c3c1c0a34..ae4093b2ad 100644 --- a/langchain/tools/file_management/write.py +++ b/langchain/tools/file_management/write.py @@ -1,9 +1,10 @@ -import os -from typing import Type +from pathlib import Path +from typing import Optional, Type from pydantic import BaseModel, Field from langchain.tools.base import BaseTool +from langchain.tools.file_management.utils import get_validated_relative_path class WriteFileInput(BaseModel): @@ -17,15 +18,22 @@ class WriteFileTool(BaseTool): name: str = "write_file" args_schema: Type[BaseModel] = WriteFileInput description: str = "Write file to disk" + root_dir: Optional[str] = None + """Directory to write file to. + + If specified, raises an error for file_paths outside root_dir.""" def _run(self, file_path: str, text: str) -> str: + write_path = ( + get_validated_relative_path(Path(self.root_dir), file_path) + if self.root_dir + else Path(file_path) + ) try: - directory = os.path.dirname(file_path) - if not os.path.exists(directory) and directory: - os.makedirs(directory) - with open(file_path, "w", encoding="utf-8") as f: + write_path.parent.mkdir(exist_ok=True, parents=False) + with write_path.open("w", encoding="utf-8") as f: f.write(text) - return "File written to successfully." + return f"File written successfully to {file_path}." 
except Exception as e: return "Error: " + str(e) diff --git a/tests/unit_tests/tools/file_management/__init__.py b/tests/unit_tests/tools/file_management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/tools/file_management/test_read.py b/tests/unit_tests/tools/file_management/test_read.py new file mode 100644 index 0000000000..c080ff9c25 --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_read.py @@ -0,0 +1,29 @@ +"""Test the ReadFile tool.""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +from langchain.tools.file_management.read import ReadFileTool + + +def test_read_file_with_root_dir() -> None: + """Test the ReadFile tool.""" + with TemporaryDirectory() as temp_dir: + with (Path(temp_dir) / "file.txt").open("w") as f: + f.write("Hello, world!") + tool = ReadFileTool(root_dir=temp_dir) + result = tool.run("file.txt") + assert result == "Hello, world!" + # Check absolute files can still be passed if they lie within the root dir. + result = tool.run(str(Path(temp_dir) / "file.txt")) + assert result == "Hello, world!" + + +def test_read_file() -> None: + """Test the ReadFile tool.""" + with TemporaryDirectory() as temp_dir: + with (Path(temp_dir) / "file.txt").open("w") as f: + f.write("Hello, world!") + tool = ReadFileTool() + result = tool.run(str(Path(temp_dir) / "file.txt")) + assert result == "Hello, world!" 
diff --git a/tests/unit_tests/tools/file_management/test_utils.py b/tests/unit_tests/tools/file_management/test_utils.py new file mode 100644 index 0000000000..31a6d01075 --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_utils.py @@ -0,0 +1,72 @@ +"""Test the File Management utils.""" + + +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from langchain.tools.file_management.utils import get_validated_relative_path + + +def test_get_validated_relative_path_errs_on_absolute() -> None: + """Safely resolve a path.""" + root = Path(__file__).parent + user_path = "/bin/bash" + matches = f"Path {user_path} is outside of the allowed directory {root}" + with pytest.raises(ValueError, match=matches): + get_validated_relative_path(root, user_path) + + +def test_get_validated_relative_path_errs_on_parent_dir() -> None: + """Safely resolve a path.""" + root = Path(__file__).parent + user_path = "data/sub/../../../sibling" + matches = f"Path {user_path} is outside of the allowed directory {root}" + with pytest.raises(ValueError, match=matches): + get_validated_relative_path(root, user_path) + + +def test_get_validated_relative_path() -> None: + """Safely resolve a path.""" + root = Path(__file__).parent + user_path = "data/sub/file.txt" + expected = root / user_path + result = get_validated_relative_path(root, user_path) + assert result == expected + + +def test_get_validated_relative_path_errs_for_symlink_outside_root() -> None: + """Test that symlink pointing outside of root directory is not allowed.""" + with TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + user_path = "symlink_outside_root" + + outside_path = Path("/bin/bash") + symlink_path = root / user_path + symlink_path.symlink_to(outside_path) + + matches = ( + f"Path {user_path} is outside of the allowed directory {root.resolve()}" + ) + with pytest.raises(ValueError, match=matches): + get_validated_relative_path(root, user_path) + + symlink_path.unlink() + 
+ +def test_get_validated_relative_path_for_symlink_inside_root() -> None: + """Test that symlink pointing inside the root directory is allowed.""" + with TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + user_path = "symlink_inside_root" + target_path = "data/sub/file.txt" + + symlink_path = root / user_path + target_path_ = root / target_path + symlink_path.symlink_to(target_path_) + + expected = target_path_.resolve() + result = get_validated_relative_path(root, user_path) + assert result == expected + symlink_path.unlink() diff --git a/tests/unit_tests/tools/file_management/test_write.py b/tests/unit_tests/tools/file_management/test_write.py new file mode 100644 index 0000000000..f6222ef89c --- /dev/null +++ b/tests/unit_tests/tools/file_management/test_write.py @@ -0,0 +1,35 @@ +"""Test the WriteFile tool.""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from langchain.tools.file_management.write import WriteFileTool + + +def test_write_file_with_root_dir() -> None: + """Test the WriteFile tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = WriteFileTool(root_dir=temp_dir) + tool.run({"file_path": "file.txt", "text": "Hello, world!"}) + assert (Path(temp_dir) / "file.txt").exists() + assert (Path(temp_dir) / "file.txt").read_text() == "Hello, world!" 
+ + +def test_write_file_errs_outside_root_dir() -> None: + """Test the WriteFile tool when a root dir is specified.""" + with TemporaryDirectory() as temp_dir: + tool = WriteFileTool(root_dir=temp_dir) + with pytest.raises(ValueError): + tool.run({"file_path": "../file.txt", "text": "Hello, world!"}) + + +def test_write_file() -> None: + """Test the WriteFile tool.""" + with TemporaryDirectory() as temp_dir: + file_path = str(Path(temp_dir) / "file.txt") + tool = WriteFileTool() + tool.run({"file_path": file_path, "text": "Hello, world!"}) + assert (Path(temp_dir) / "file.txt").exists() + assert (Path(temp_dir) / "file.txt").read_text() == "Hello, world!"