diff --git a/langchain/document_loaders/blob_loaders/file_system.py b/langchain/document_loaders/blob_loaders/file_system.py index 48c965aa..4705b2c5 100644 --- a/langchain/document_loaders/blob_loaders/file_system.py +++ b/langchain/document_loaders/blob_loaders/file_system.py @@ -1,9 +1,38 @@ """Use to load blobs from the local file system.""" from pathlib import Path -from typing import Iterable, Optional, Sequence, Union +from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union from langchain.document_loaders.blob_loaders.schema import Blob, BlobLoader +T = TypeVar("T") + + +def _make_iterator( + length_func: Callable[[], int], show_progress: bool = False +) -> Callable[[Iterable[T]], Iterator[T]]: + """Create a function that optionally wraps an iterable in tqdm.""" + if show_progress: + try: + from tqdm.auto import tqdm + except ImportError: + raise ImportError( + "You must install tqdm to use show_progress=True." + "You can install tqdm with `pip install tqdm`." + ) + + # Make sure to provide `total` here so that tqdm can show + # a progress bar that takes into account the total number of files. + def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]: + """Wrap an iterable in a tqdm progress bar.""" + return tqdm(iterable, total=length_func()) + + iterator = _with_tqdm + else: + iterator = iter # type: ignore + + return iterator + + # PUBLIC API @@ -26,6 +55,7 @@ class FileSystemBlobLoader(BlobLoader): *, glob: str = "**/[!.]*", suffixes: Optional[Sequence[str]] = None, + show_progress: bool = False, ) -> None: """Initialize with path to directory and how to glob over it. @@ -36,6 +66,9 @@ class FileSystemBlobLoader(BlobLoader): suffixes: Provide to keep only files with these suffixes Useful when wanting to keep files with different suffixes Suffixes must include the dot, e.g. ".txt" + show_progress: If true, will show a progress bar as the files are loaded. + This forces an iteration through all matching files + to count them prior to loading them. Examples: @@ -60,14 +93,33 @@ class FileSystemBlobLoader(BlobLoader): self.path = _path self.glob = glob self.suffixes = set(suffixes or []) + self.show_progress = show_progress def yield_blobs( self, ) -> Iterable[Blob]: """Yield blobs that match the requested pattern.""" + iterator = _make_iterator( + length_func=self.count_matching_files, show_progress=self.show_progress + ) + + for path in iterator(self._yield_paths()): + yield Blob.from_path(path) + + def _yield_paths(self) -> Iterable[Path]: + """Yield paths that match the requested pattern.""" paths = self.path.glob(self.glob) for path in paths: if path.is_file(): if self.suffixes and path.suffix not in self.suffixes: continue - yield Blob.from_path(str(path)) + yield path + + def count_matching_files(self) -> int: + """Count files that match the pattern without loading them.""" + # Carry out a full iteration to count the files without + # materializing anything expensive in memory. + num = 0 + for _ in self._yield_paths(): + num += 1 + return num diff --git a/pyproject.toml b/pyproject.toml index 32f20f69..526f9c33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,3 +184,17 @@ omit = [ [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +# --strict-markers will raise errors on unknown marks. +# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks +# +# https://docs.pytest.org/en/7.1.x/reference/reference.html +# --strict-config any warnings encountered while parsing the `pytest` +# section of the configuration file raise errors. +addopts = "--strict-markers --strict-config --durations=5" +# Registering custom markers. +# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers +markers = [ + "requires: mark tests as requiring a specific library" +] diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py new file mode 100644 index 00000000..db65d118 --- /dev/null +++ b/tests/unit_tests/conftest.py @@ -0,0 +1,44 @@ +"""Configuration for unit tests.""" +from importlib import util +from typing import Dict, Sequence + +import pytest +from pytest import Config, Function + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. + # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + required_pkgs_info[pkg] = util.find_spec(pkg) is not None + + if not required_pkgs_info[pkg]: + # If the package is not installed, we immediately break + # and mark the test as skipped. + item.add_marker(pytest.mark.skip(reason=f"requires pkg: `{pkg}`")) + break diff --git a/tests/unit_tests/document_loader/blob_loaders/test_filesystem_blob_loader.py b/tests/unit_tests/document_loader/blob_loaders/test_filesystem_blob_loader.py index 37bcd472..0c40bc08 100644 --- a/tests/unit_tests/document_loader/blob_loaders/test_filesystem_blob_loader.py +++ b/tests/unit_tests/document_loader/blob_loaders/test_filesystem_blob_loader.py @@ -91,6 +91,8 @@ def test_file_names_exist( loader = FileSystemBlobLoader(toy_dir, glob=glob, suffixes=suffixes) blobs = list(loader.yield_blobs()) + assert loader.count_matching_files() == len(relative_filenames) + file_names = sorted(str(blob.path) for blob in blobs) expected_filenames = sorted( @@ -99,3 +101,11 @@ def test_file_names_exist( ) assert file_names == expected_filenames + + +@pytest.mark.requires("tqdm") +def test_show_progress(toy_dir: str) -> None: + """Verify that file system loader works with a progress bar.""" + loader = FileSystemBlobLoader(toy_dir) + blobs = list(loader.yield_blobs()) + assert len(blobs) == loader.count_matching_files()