forked from Archives/langchain
Update File Management Tools to Include Root Directory (#3112)
- Permit the specification of a `root_dir` to the read/write file tools to specify a working directory - Add validation for attempts to read/write outside the directory (e.g., through `../../` or symlinks or `/abs/path`'s that don't lie in the correct path) - Add some tests for all One question is whether we should make a default root directory for these? tradeoffs either way
This commit is contained in:
parent
a63bfb6c9f
commit
4adfd790f0
@ -1,8 +1,10 @@
|
|||||||
from typing import Type
|
from pathlib import Path
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||||
|
|
||||||
|
|
||||||
class ReadFileInput(BaseModel):
|
class ReadFileInput(BaseModel):
|
||||||
@ -15,10 +17,19 @@ class ReadFileTool(BaseTool):
|
|||||||
name: str = "read_file"
|
name: str = "read_file"
|
||||||
args_schema: Type[BaseModel] = ReadFileInput
|
args_schema: Type[BaseModel] = ReadFileInput
|
||||||
description: str = "Read file from disk"
|
description: str = "Read file from disk"
|
||||||
|
root_dir: Optional[str] = None
|
||||||
|
"""Directory to read file from.
|
||||||
|
|
||||||
|
If specified, raises an error for file_paths oustide root_dir."""
|
||||||
|
|
||||||
def _run(self, file_path: str) -> str:
|
def _run(self, file_path: str) -> str:
|
||||||
|
read_path = (
|
||||||
|
get_validated_relative_path(Path(self.root_dir), file_path)
|
||||||
|
if self.root_dir
|
||||||
|
else Path(file_path)
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with read_path.open("r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
return content
|
return content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
26
langchain/tools/file_management/utils.py
Normal file
26
langchain/tools/file_management/utils.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def is_relative_to(path: Path, root: Path) -> bool:
|
||||||
|
"""Check if path is relative to root."""
|
||||||
|
if sys.version_info >= (3, 9):
|
||||||
|
# No need for a try/except block in Python 3.8+.
|
||||||
|
return path.is_relative_to(root)
|
||||||
|
try:
|
||||||
|
path.relative_to(root)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_validated_relative_path(root: Path, user_path: str) -> Path:
|
||||||
|
"""Resolve a relative path, raising an error if not within the root directory."""
|
||||||
|
# Note, this still permits symlinks from outside that point within the root.
|
||||||
|
# Further validation would be needed if those are to be disallowed.
|
||||||
|
root = root.resolve()
|
||||||
|
full_path = (root / user_path).resolve()
|
||||||
|
|
||||||
|
if not is_relative_to(full_path, root):
|
||||||
|
raise ValueError(f"Path {user_path} is outside of the allowed directory {root}")
|
||||||
|
return full_path
|
@ -1,9 +1,10 @@
|
|||||||
import os
|
from pathlib import Path
|
||||||
from typing import Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||||
|
|
||||||
|
|
||||||
class WriteFileInput(BaseModel):
|
class WriteFileInput(BaseModel):
|
||||||
@ -17,15 +18,22 @@ class WriteFileTool(BaseTool):
|
|||||||
name: str = "write_file"
|
name: str = "write_file"
|
||||||
args_schema: Type[BaseModel] = WriteFileInput
|
args_schema: Type[BaseModel] = WriteFileInput
|
||||||
description: str = "Write file to disk"
|
description: str = "Write file to disk"
|
||||||
|
root_dir: Optional[str] = None
|
||||||
|
"""Directory to write file to.
|
||||||
|
|
||||||
|
If specified, raises an error for file_paths oustide root_dir."""
|
||||||
|
|
||||||
def _run(self, file_path: str, text: str) -> str:
|
def _run(self, file_path: str, text: str) -> str:
|
||||||
|
write_path = (
|
||||||
|
get_validated_relative_path(Path(self.root_dir), file_path)
|
||||||
|
if self.root_dir
|
||||||
|
else Path(file_path)
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
directory = os.path.dirname(file_path)
|
write_path.parent.mkdir(exist_ok=True, parents=False)
|
||||||
if not os.path.exists(directory) and directory:
|
with write_path.open("w", encoding="utf-8") as f:
|
||||||
os.makedirs(directory)
|
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(text)
|
f.write(text)
|
||||||
return "File written to successfully."
|
return f"File written successfully to {file_path}."
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return "Error: " + str(e)
|
return "Error: " + str(e)
|
||||||
|
|
||||||
|
0
tests/unit_tests/tools/file_management/__init__.py
Normal file
0
tests/unit_tests/tools/file_management/__init__.py
Normal file
29
tests/unit_tests/tools/file_management/test_read.py
Normal file
29
tests/unit_tests/tools/file_management/test_read.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Test the ReadFile tool."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
from langchain.tools.file_management.read import ReadFileTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_file_with_root_dir() -> None:
|
||||||
|
"""Test the ReadFile tool."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
with (Path(temp_dir) / "file.txt").open("w") as f:
|
||||||
|
f.write("Hello, world!")
|
||||||
|
tool = ReadFileTool(root_dir=temp_dir)
|
||||||
|
result = tool.run("file.txt")
|
||||||
|
assert result == "Hello, world!"
|
||||||
|
# Check absolute files can still be passed if they lie within the root dir.
|
||||||
|
result = tool.run(str(Path(temp_dir) / "file.txt"))
|
||||||
|
assert result == "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_file() -> None:
|
||||||
|
"""Test the ReadFile tool."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
with (Path(temp_dir) / "file.txt").open("w") as f:
|
||||||
|
f.write("Hello, world!")
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = tool.run(str(Path(temp_dir) / "file.txt"))
|
||||||
|
assert result == "Hello, world!"
|
72
tests/unit_tests/tools/file_management/test_utils.py
Normal file
72
tests/unit_tests/tools/file_management/test_utils.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""Test the File Management utils."""
|
||||||
|
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validated_relative_path_errs_on_absolute() -> None:
|
||||||
|
"""Safely resolve a path."""
|
||||||
|
root = Path(__file__).parent
|
||||||
|
user_path = "/bin/bash"
|
||||||
|
matches = f"Path {user_path} is outside of the allowed directory {root}"
|
||||||
|
with pytest.raises(ValueError, match=matches):
|
||||||
|
get_validated_relative_path(root, user_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validated_relative_path_errs_on_parent_dir() -> None:
|
||||||
|
"""Safely resolve a path."""
|
||||||
|
root = Path(__file__).parent
|
||||||
|
user_path = "data/sub/../../../sibling"
|
||||||
|
matches = f"Path {user_path} is outside of the allowed directory {root}"
|
||||||
|
with pytest.raises(ValueError, match=matches):
|
||||||
|
get_validated_relative_path(root, user_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validated_relative_path() -> None:
|
||||||
|
"""Safely resolve a path."""
|
||||||
|
root = Path(__file__).parent
|
||||||
|
user_path = "data/sub/file.txt"
|
||||||
|
expected = root / user_path
|
||||||
|
result = get_validated_relative_path(root, user_path)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validated_relative_path_errs_for_symlink_outside_root() -> None:
|
||||||
|
"""Test that symlink pointing outside of root directory is not allowed."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
root = Path(temp_dir)
|
||||||
|
user_path = "symlink_outside_root"
|
||||||
|
|
||||||
|
outside_path = Path("/bin/bash")
|
||||||
|
symlink_path = root / user_path
|
||||||
|
symlink_path.symlink_to(outside_path)
|
||||||
|
|
||||||
|
matches = (
|
||||||
|
f"Path {user_path} is outside of the allowed directory {root.resolve()}"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match=matches):
|
||||||
|
get_validated_relative_path(root, user_path)
|
||||||
|
|
||||||
|
symlink_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validated_relative_path_for_symlink_inside_root() -> None:
|
||||||
|
"""Test that symlink pointing inside the root directory is allowed."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
root = Path(temp_dir)
|
||||||
|
user_path = "symlink_inside_root"
|
||||||
|
target_path = "data/sub/file.txt"
|
||||||
|
|
||||||
|
symlink_path = root / user_path
|
||||||
|
target_path_ = root / target_path
|
||||||
|
symlink_path.symlink_to(target_path_)
|
||||||
|
|
||||||
|
expected = target_path_.resolve()
|
||||||
|
result = get_validated_relative_path(root, user_path)
|
||||||
|
assert result == expected
|
||||||
|
symlink_path.unlink()
|
35
tests/unit_tests/tools/file_management/test_write.py
Normal file
35
tests/unit_tests/tools/file_management/test_write.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""Test the WriteFile tool."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.tools.file_management.write import WriteFileTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_file_with_root_dir() -> None:
|
||||||
|
"""Test the WriteFile tool when a root dir is specified."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
tool = WriteFileTool(root_dir=temp_dir)
|
||||||
|
tool.run({"file_path": "file.txt", "text": "Hello, world!"})
|
||||||
|
assert (Path(temp_dir) / "file.txt").exists()
|
||||||
|
assert (Path(temp_dir) / "file.txt").read_text() == "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_file_errs_outside_root_dir() -> None:
|
||||||
|
"""Test the WriteFile tool when a root dir is specified."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
tool = WriteFileTool(root_dir=temp_dir)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tool.run({"file_path": "../file.txt", "text": "Hello, world!"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_file() -> None:
|
||||||
|
"""Test the WriteFile tool."""
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
file_path = str(Path(temp_dir) / "file.txt")
|
||||||
|
tool = WriteFileTool()
|
||||||
|
tool.run({"file_path": file_path, "text": "Hello, world!"})
|
||||||
|
assert (Path(temp_dir) / "file.txt").exists()
|
||||||
|
assert (Path(temp_dir) / "file.txt").read_text() == "Hello, world!"
|
Loading…
Reference in New Issue
Block a user