Add progress bar to filesystemblob loader, update pytest config for unit tests (#4212)

This PR adds:

* Option to show a tqdm progress bar when using the file system blob loader
* Update pytest run configuration to be stricter
* Add a new `requires` pytest marker that skips a test when any of its required packages is not installed
parallel_dir_loader
Eugene Yurtsev 1 year ago committed by GitHub
parent f4c8502e61
commit aa11f7c89b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,9 +1,38 @@
"""Use to load blobs from the local file system."""
from pathlib import Path
from typing import Iterable, Optional, Sequence, Union
from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union
from langchain.document_loaders.blob_loaders.schema import Blob, BlobLoader
T = TypeVar("T")


def _make_iterator(
    length_func: Callable[[], int], show_progress: bool = False
) -> Callable[[Iterable[T]], Iterator[T]]:
    """Create a function that optionally wraps an iterable in tqdm.

    Args:
        length_func: Zero-argument callable returning the total number of
            items. It is only invoked when a progress bar is actually
            created, i.e. when the returned wrapper is called with
            ``show_progress=True``.
        show_progress: If true, the returned function wraps its argument in a
            tqdm progress bar; otherwise the builtin ``iter`` is returned.

    Returns:
        A callable that takes an iterable and returns an iterator over it.

    Raises:
        ImportError: If ``show_progress`` is true and tqdm cannot be imported.
    """
    if not show_progress:
        return iter  # type: ignore

    try:
        from tqdm.auto import tqdm
    except ImportError as e:
        # The trailing space after the first sentence matters: the two
        # literals are concatenated into a single message.
        raise ImportError(
            "You must install tqdm to use show_progress=True. "
            "You can install tqdm with `pip install tqdm`."
        ) from e

    def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]:
        """Wrap an iterable in a tqdm progress bar."""
        # Make sure to provide `total` here so that tqdm can show
        # a progress bar that takes into account the total number of items.
        return tqdm(iterable, total=length_func())

    return _with_tqdm
# PUBLIC API
@ -26,6 +55,7 @@ class FileSystemBlobLoader(BlobLoader):
*,
glob: str = "**/[!.]*",
suffixes: Optional[Sequence[str]] = None,
show_progress: bool = False,
) -> None:
"""Initialize with path to directory and how to glob over it.
@ -36,6 +66,9 @@ class FileSystemBlobLoader(BlobLoader):
suffixes: Provide to keep only files with these suffixes
Useful when wanting to keep files with different suffixes
Suffixes must include the dot, e.g. ".txt"
show_progress: If true, will show a progress bar as the files are loaded.
This forces an iteration through all matching files
to count them prior to loading them.
Examples:
@ -60,14 +93,33 @@ class FileSystemBlobLoader(BlobLoader):
self.path = _path
self.glob = glob
self.suffixes = set(suffixes or [])
self.show_progress = show_progress
def yield_blobs(self) -> Iterable[Blob]:
    """Lazily produce a blob for every file matching the loader's filters.

    The path stream from ``_yield_paths`` is passed through the wrapper
    built by ``_make_iterator``, which shows a progress bar when
    ``self.show_progress`` is true. ``count_matching_files`` is handed over
    as the length function so the bar can display a total.
    """
    wrap = _make_iterator(
        length_func=self.count_matching_files, show_progress=self.show_progress
    )
    yield from (Blob.from_path(file_path) for file_path in wrap(self._yield_paths()))
def _yield_paths(self) -> Iterable[Path]:
    """Yield paths that match the requested pattern.

    Globs ``self.path`` with ``self.glob``, skips anything that is not a
    file, and, when ``self.suffixes`` is non-empty, keeps only files whose
    suffix is in that set.

    Yields:
        One ``Path`` per matching file.
    """
    for path in self.path.glob(self.glob):
        if not path.is_file():
            continue
        if self.suffixes and path.suffix not in self.suffixes:
            continue
        # Yield exactly one Path per file. Blob construction is done by
        # ``yield_blobs``; yielding anything else here would contradict the
        # return annotation and inflate ``count_matching_files``.
        yield path
def count_matching_files(self) -> int:
    """Count files that match the pattern without loading them.

    Returns:
        The number of items produced by ``self._yield_paths()``.
    """
    # Consume the path stream one item at a time so that nothing is
    # materialized in memory just to be counted.
    return sum(1 for _ in self._yield_paths())

@ -184,3 +184,17 @@ omit = [
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library"
]

@ -0,0 +1,44 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence
import pytest
from pytest import Config, Function
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...

    Args:
        config: The pytest configuration object (not used by this hook).
        items: The collected test items; a skip marker is added in place to
            every item whose required packages are not all installed.
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is None:
            continue

        # The marker's positional arguments are the required package names.
        for pkg in requires_marker.args:
            # If we haven't yet checked whether the pkg is installed
            # let's check it and store the result.
            if pkg not in required_pkgs_info:
                try:
                    installed = util.find_spec(pkg) is not None
                except ImportError:
                    # For a dotted name, find_spec imports the parent package
                    # and raises ModuleNotFoundError when that parent is
                    # missing. That means "not installed", not a collection
                    # error.
                    installed = False
                required_pkgs_info[pkg] = installed

            if not required_pkgs_info[pkg]:
                # If the package is not installed, we immediately break
                # and mark the test as skipped.
                item.add_marker(pytest.mark.skip(reason=f"requires pkg: `{pkg}`"))
                break

@ -91,6 +91,8 @@ def test_file_names_exist(
loader = FileSystemBlobLoader(toy_dir, glob=glob, suffixes=suffixes)
blobs = list(loader.yield_blobs())
assert loader.count_matching_files() == len(relative_filenames)
file_names = sorted(str(blob.path) for blob in blobs)
expected_filenames = sorted(
@ -99,3 +101,11 @@ def test_file_names_exist(
)
assert file_names == expected_filenames
@pytest.mark.requires("tqdm")
def test_show_progress(toy_dir: str) -> None:
    """Verify that file system loader works with a progress bar."""
    # show_progress=True is what actually routes iteration through tqdm;
    # without it this test never touches the progress-bar code path.
    loader = FileSystemBlobLoader(toy_dir, show_progress=True)
    blobs = list(loader.yield_blobs())
    assert len(blobs) == loader.count_matching_files()

Loading…
Cancel
Save