From 95d3d08d27eb24d4a7743928a2ecacbffac41d8d Mon Sep 17 00:00:00 2001 From: Bryce Date: Thu, 13 Apr 2023 23:33:46 -0700 Subject: [PATCH] fix: better input image path handling throw exception for non-existent images --- .gitignore | 1 + imaginairy/cli/shared.py | 6 +++++ imaginairy/utils.py | 3 ++- requirements-dev.txt | 56 +++++++++++++++++++--------------------- tests/test_utils.py | 52 ++++++++++++++++++++++++++++++++++--- 5 files changed, 84 insertions(+), 34 deletions(-) diff --git a/.gitignore b/.gitignore index 4f10a2c..5158cd8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea +.vscode .DS_Store __pycache__ outputs/* diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py index d9ac30c..a0f5c0e 100644 --- a/imaginairy/cli/shared.py +++ b/imaginairy/cli/shared.py @@ -89,8 +89,14 @@ def _imagine_cmd( from imaginairy.utils import glob_expand_paths + num_prexpaned_init_images = len(init_images) init_images = glob_expand_paths(init_images) + if len(init_images) < num_prexpaned_init_images: + raise ValueError( + f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?" + ) + total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats logger.info( f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images." diff --git a/imaginairy/utils.py b/imaginairy/utils.py index be329b8..715d840 100644 --- a/imaginairy/utils.py +++ b/imaginairy/utils.py @@ -204,11 +204,12 @@ def shrink_list(items, max_size): def glob_expand_paths(paths): import glob + import os.path expanded_paths = [] for p in paths: if p.startswith("http"): expanded_paths.append(p) else: - expanded_paths.extend(glob.glob(p)) + expanded_paths.extend(glob.glob(os.path.expanduser(p))) return expanded_paths diff --git a/requirements-dev.txt b/requirements-dev.txt index af2d89b..eba25fc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,15 +10,13 @@ aiosignal==1.3.1 # via aiohttp antlr4-python3-runtime==4.9.3 # via omegaconf -astroid==2.15.0 +astroid==2.15.2 # via pylint async-timeout==4.0.2 # via aiohttp attrs==22.2.0 - # via - # aiohttp - # pytest -black==23.1.0 + # via aiohttp +black==23.3.0 # via -r requirements-dev.in certifi==2022.12.7 # via requests @@ -39,11 +37,11 @@ click-shell==2.1 # via imaginAIry (setup.py) contourpy==1.0.7 # via matplotlib -coverage==7.2.2 +coverage==7.2.3 # via -r requirements-dev.in cycler==0.11.0 # via matplotlib -diffusers==0.14.0 +diffusers==0.15.0 # via imaginAIry (setup.py) dill==0.3.6 # via pylint @@ -55,26 +53,26 @@ facexlib==0.2.5 # via imaginAIry (setup.py) fairscale==0.4.13 # via imaginAIry (setup.py) -filelock==3.10.0 +filelock==3.11.0 # via # diffusers # huggingface-hub # transformers filterpy==1.4.5 # via facexlib -fonttools==4.39.2 +fonttools==4.39.3 # via matplotlib frozenlist==1.3.3 # via # aiohttp # aiosignal -fsspec[http]==2023.3.0 +fsspec[http]==2023.4.0 # via pytorch-lightning ftfy==6.1.1 # via # imaginAIry (setup.py) # open-clip-torch -huggingface-hub==0.13.2 +huggingface-hub==0.13.4 # via # diffusers # open-clip-torch @@ -84,9 +82,9 @@ idna==3.4 # via # requests # yarl -imageio==2.26.0 +imageio==2.27.0 # via imaginAIry (setup.py) -importlib-metadata==6.0.0 +importlib-metadata==6.3.0 # via diffusers iniconfig==2.0.0 # via pytest @@ -96,7 +94,7 @@ isort==5.12.0 # pylint kiwisolver==1.4.4 # via matplotlib -kornia==0.6.10 +kornia==0.6.11 # via imaginAIry (setup.py) lazy-object-proxy==1.9.0 # via astroid @@ -147,7 +145,7 @@ opencv-python==4.7.0.72 # via # facexlib # imaginAIry (setup.py) -packaging==23.0 +packaging==23.1 # via # black # huggingface-hub @@ -163,7 +161,7 @@ pathspec==0.10.3 # via # black # pycln -pillow==9.4.0 +pillow==9.5.0 # via # diffusers # facexlib @@ -171,7 +169,7 @@ pillow==9.4.0 # imaginAIry (setup.py) # matplotlib # torchvision -platformdirs==3.1.1 +platformdirs==3.2.0 # via # black # pylint @@ -193,22 +191,22 @@ pyflakes==3.0.1 # via pylama pylama==8.4.1 # via -r requirements-dev.in -pylint==2.17.0 +pylint==2.17.2 # via -r requirements-dev.in pyparsing==3.0.9 # via matplotlib -pytest==7.2.2 +pytest==7.3.0 # via # -r requirements-dev.in # pytest-randomly # pytest-sugar pytest-randomly==3.12.0 # via -r requirements-dev.in -pytest-sugar==0.9.6 +pytest-sugar==0.9.7 # via -r requirements-dev.in python-dateutil==2.8.2 # via matplotlib -pytorch-lightning==1.9.4 +pytorch-lightning==1.9.5 # via imaginAIry (setup.py) pyyaml==6.0 # via @@ -220,7 +218,7 @@ pyyaml==6.0 # responses # timm # transformers -regex==2022.10.31 +regex==2023.3.23 # via # diffusers # open-clip-torch @@ -236,7 +234,7 @@ requests==2.28.2 # transformers responses==0.23.1 # via -r requirements-dev.in -ruff==0.0.256 +ruff==0.0.261 # via -r requirements-dev.in safetensors==0.3.0 # via imaginAIry (setup.py) @@ -245,7 +243,7 @@ scipy==1.10.1 # facexlib # filterpy # torchdiffeq -sentencepiece==0.1.97 +sentencepiece==0.1.98 # via open-clip-torch six==1.16.0 # via python-dateutil @@ -253,18 +251,18 @@ snowballstemmer==2.2.0 # via pydocstyle termcolor==2.2.0 # via pytest-sugar -timm==0.6.12 +timm==0.6.13 # via # imaginAIry (setup.py) # open-clip-torch -tokenizers==0.13.2 +tokenizers==0.13.3 # via transformers tomli==2.0.1 # via # black # pylint # pytest -tomlkit==0.11.6 +tomlkit==0.11.7 # via # pycln # pylint @@ -300,11 +298,11 @@ tqdm==4.65.0 # open-clip-torch # pytorch-lightning # transformers -transformers==4.27.1 +transformers==4.28.0 # via imaginAIry (setup.py) typer==0.7.0 # via pycln -types-pyyaml==6.0.12.8 +types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 # via diff --git a/tests/test_utils.py b/tests/test_utils.py index 9bfd052..293bede 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,6 +11,7 @@ from imaginairy.utils import ( get_device, get_hardware_description, get_obj_from_str, + glob_expand_paths, instantiate_from_config, ) @@ -80,7 +81,50 @@ def test_instantiate_from_config(): instantiate_from_config(config) -# -# def test_platform_appropriate_autocast(): -# with platform_appropriate_autocast("autocast"): -# pass +class TestGlobExpandPaths: + def test_valid_file_paths(self, tmp_path): + # create temporary file + file_path = tmp_path / "test.txt" + file_path.touch() + + # test function with valid file path + result = glob_expand_paths([str(file_path)]) + assert result == [str(file_path)] + + def test_valid_http_urls(self): + # test function with valid http url + result = glob_expand_paths(["http://www.example.com"]) + assert result == ["http://www.example.com"] + + def test_file_paths_with_wildcards(self, tmp_path): + # create temporary files + file1 = tmp_path / "test1.txt" + file1.touch() + file2 = tmp_path / "test2.txt" + file2.touch() + + # test function with file path containing wildcard + result = glob_expand_paths([str(tmp_path / "*.txt")]) + result.sort() + assert result == [str(file1), str(file2)] + + def test_empty_input(self): + # test function with empty input list + result = glob_expand_paths([]) + assert not result + + def test_nonexistent_file_paths(self): + # test function with non-existent file path + result = glob_expand_paths(["/nonexistent/path"]) + assert not result + + def test_user_expansion(self, monkeypatch, tmp_path): + file1 = tmp_path / "test1.txt" + file1.touch() + + # monkeypatch os.path.expanduser to return a known path + monkeypatch.setattr("os.path.expanduser", lambda x: str(tmp_path / "test1.txt")) + + # test function with user expansion + paths = ["~/file.txt"] + assert glob_expand_paths(paths) == [str(tmp_path / "test1.txt")]