Update File Management Tools to Include Root Directory (#3112)

- Permit the specification of a `root_dir` to the read/write file tools
to specify a working directory
- Add validation for attempts to read/write outside the directory (e.g.,
through `../../` or symlinks or `/abs/path`'s that don't lie in the
correct path)
- Add some tests for all


One question is whether we should make a default root directory for
these? tradeoffs either way
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent a63bfb6c9f
commit 4adfd790f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,10 @@
from typing import Type from pathlib import Path
from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
class ReadFileInput(BaseModel): class ReadFileInput(BaseModel):
@ -15,10 +17,19 @@ class ReadFileTool(BaseTool):
name: str = "read_file" name: str = "read_file"
args_schema: Type[BaseModel] = ReadFileInput args_schema: Type[BaseModel] = ReadFileInput
description: str = "Read file from disk" description: str = "Read file from disk"
root_dir: Optional[str] = None
"""Directory to read file from.
If specified, raises an error for file_paths oustide root_dir."""
def _run(self, file_path: str) -> str: def _run(self, file_path: str) -> str:
read_path = (
get_validated_relative_path(Path(self.root_dir), file_path)
if self.root_dir
else Path(file_path)
)
try: try:
with open(file_path, "r", encoding="utf-8") as f: with read_path.open("r", encoding="utf-8") as f:
content = f.read() content = f.read()
return content return content
except Exception as e: except Exception as e:

@ -0,0 +1,26 @@
import sys
from pathlib import Path
def is_relative_to(path: Path, root: Path) -> bool:
"""Check if path is relative to root."""
if sys.version_info >= (3, 9):
# No need for a try/except block in Python 3.8+.
return path.is_relative_to(root)
try:
path.relative_to(root)
return True
except ValueError:
return False
def get_validated_relative_path(root: Path, user_path: str) -> Path:
"""Resolve a relative path, raising an error if not within the root directory."""
# Note, this still permits symlinks from outside that point within the root.
# Further validation would be needed if those are to be disallowed.
root = root.resolve()
full_path = (root / user_path).resolve()
if not is_relative_to(full_path, root):
raise ValueError(f"Path {user_path} is outside of the allowed directory {root}")
return full_path

@ -1,9 +1,10 @@
import os from pathlib import Path
from typing import Type from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
class WriteFileInput(BaseModel): class WriteFileInput(BaseModel):
@ -17,15 +18,22 @@ class WriteFileTool(BaseTool):
name: str = "write_file" name: str = "write_file"
args_schema: Type[BaseModel] = WriteFileInput args_schema: Type[BaseModel] = WriteFileInput
description: str = "Write file to disk" description: str = "Write file to disk"
root_dir: Optional[str] = None
"""Directory to write file to.
If specified, raises an error for file_paths oustide root_dir."""
def _run(self, file_path: str, text: str) -> str: def _run(self, file_path: str, text: str) -> str:
write_path = (
get_validated_relative_path(Path(self.root_dir), file_path)
if self.root_dir
else Path(file_path)
)
try: try:
directory = os.path.dirname(file_path) write_path.parent.mkdir(exist_ok=True, parents=False)
if not os.path.exists(directory) and directory: with write_path.open("w", encoding="utf-8") as f:
os.makedirs(directory)
with open(file_path, "w", encoding="utf-8") as f:
f.write(text) f.write(text)
return "File written to successfully." return f"File written successfully to {file_path}."
except Exception as e: except Exception as e:
return "Error: " + str(e) return "Error: " + str(e)

@ -0,0 +1,29 @@
"""Test the ReadFile tool."""
from pathlib import Path
from tempfile import TemporaryDirectory
from langchain.tools.file_management.read import ReadFileTool
def test_read_file_with_root_dir() -> None:
"""Test the ReadFile tool."""
with TemporaryDirectory() as temp_dir:
with (Path(temp_dir) / "file.txt").open("w") as f:
f.write("Hello, world!")
tool = ReadFileTool(root_dir=temp_dir)
result = tool.run("file.txt")
assert result == "Hello, world!"
# Check absolute files can still be passed if they lie within the root dir.
result = tool.run(str(Path(temp_dir) / "file.txt"))
assert result == "Hello, world!"
def test_read_file() -> None:
"""Test the ReadFile tool."""
with TemporaryDirectory() as temp_dir:
with (Path(temp_dir) / "file.txt").open("w") as f:
f.write("Hello, world!")
tool = ReadFileTool()
result = tool.run(str(Path(temp_dir) / "file.txt"))
assert result == "Hello, world!"

@ -0,0 +1,72 @@
"""Test the File Management utils."""
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from langchain.tools.file_management.utils import get_validated_relative_path
def test_get_validated_relative_path_errs_on_absolute() -> None:
"""Safely resolve a path."""
root = Path(__file__).parent
user_path = "/bin/bash"
matches = f"Path {user_path} is outside of the allowed directory {root}"
with pytest.raises(ValueError, match=matches):
get_validated_relative_path(root, user_path)
def test_get_validated_relative_path_errs_on_parent_dir() -> None:
"""Safely resolve a path."""
root = Path(__file__).parent
user_path = "data/sub/../../../sibling"
matches = f"Path {user_path} is outside of the allowed directory {root}"
with pytest.raises(ValueError, match=matches):
get_validated_relative_path(root, user_path)
def test_get_validated_relative_path() -> None:
"""Safely resolve a path."""
root = Path(__file__).parent
user_path = "data/sub/file.txt"
expected = root / user_path
result = get_validated_relative_path(root, user_path)
assert result == expected
def test_get_validated_relative_path_errs_for_symlink_outside_root() -> None:
"""Test that symlink pointing outside of root directory is not allowed."""
with TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
user_path = "symlink_outside_root"
outside_path = Path("/bin/bash")
symlink_path = root / user_path
symlink_path.symlink_to(outside_path)
matches = (
f"Path {user_path} is outside of the allowed directory {root.resolve()}"
)
with pytest.raises(ValueError, match=matches):
get_validated_relative_path(root, user_path)
symlink_path.unlink()
def test_get_validated_relative_path_for_symlink_inside_root() -> None:
"""Test that symlink pointing inside the root directory is allowed."""
with TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
user_path = "symlink_inside_root"
target_path = "data/sub/file.txt"
symlink_path = root / user_path
target_path_ = root / target_path
symlink_path.symlink_to(target_path_)
expected = target_path_.resolve()
result = get_validated_relative_path(root, user_path)
assert result == expected
symlink_path.unlink()

@ -0,0 +1,35 @@
"""Test the WriteFile tool."""
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from langchain.tools.file_management.write import WriteFileTool
def test_write_file_with_root_dir() -> None:
"""Test the WriteFile tool when a root dir is specified."""
with TemporaryDirectory() as temp_dir:
tool = WriteFileTool(root_dir=temp_dir)
tool.run({"file_path": "file.txt", "text": "Hello, world!"})
assert (Path(temp_dir) / "file.txt").exists()
assert (Path(temp_dir) / "file.txt").read_text() == "Hello, world!"
def test_write_file_errs_outside_root_dir() -> None:
"""Test the WriteFile tool when a root dir is specified."""
with TemporaryDirectory() as temp_dir:
tool = WriteFileTool(root_dir=temp_dir)
with pytest.raises(ValueError):
tool.run({"file_path": "../file.txt", "text": "Hello, world!"})
def test_write_file() -> None:
"""Test the WriteFile tool."""
with TemporaryDirectory() as temp_dir:
file_path = str(Path(temp_dir) / "file.txt")
tool = WriteFileTool()
tool.run({"file_path": file_path, "text": "Hello, world!"})
assert (Path(temp_dir) / "file.txt").exists()
assert (Path(temp_dir) / "file.txt").read_text() == "Hello, world!"
Loading…
Cancel
Save