fix: better input image path handling

throw exception for non-existent images
This commit is contained in:
Bryce 2023-04-13 23:33:46 -07:00 committed by Bryce Drennan
parent 3012d28357
commit 95d3d08d27
5 changed files with 84 additions and 34 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
.idea
.vscode
.DS_Store
__pycache__
outputs/*

View File

@ -89,8 +89,14 @@ def _imagine_cmd(
from imaginairy.utils import glob_expand_paths
num_prexpaned_init_images = len(init_images)
init_images = glob_expand_paths(init_images)
if len(init_images) < num_prexpaned_init_images:
raise ValueError(
f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?"
)
total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
logger.info(
f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images."

View File

@ -204,11 +204,12 @@ def shrink_list(items, max_size):
def glob_expand_paths(paths):
import glob
import os.path
expanded_paths = []
for p in paths:
if p.startswith("http"):
expanded_paths.append(p)
else:
expanded_paths.extend(glob.glob(p))
expanded_paths.extend(glob.glob(os.path.expanduser(p)))
return expanded_paths

View File

@ -10,15 +10,13 @@ aiosignal==1.3.1
# via aiohttp
antlr4-python3-runtime==4.9.3
# via omegaconf
astroid==2.15.0
astroid==2.15.2
# via pylint
async-timeout==4.0.2
# via aiohttp
attrs==22.2.0
# via
# aiohttp
# pytest
black==23.1.0
# via aiohttp
black==23.3.0
# via -r requirements-dev.in
certifi==2022.12.7
# via requests
@ -39,11 +37,11 @@ click-shell==2.1
# via imaginAIry (setup.py)
contourpy==1.0.7
# via matplotlib
coverage==7.2.2
coverage==7.2.3
# via -r requirements-dev.in
cycler==0.11.0
# via matplotlib
diffusers==0.14.0
diffusers==0.15.0
# via imaginAIry (setup.py)
dill==0.3.6
# via pylint
@ -55,26 +53,26 @@ facexlib==0.2.5
# via imaginAIry (setup.py)
fairscale==0.4.13
# via imaginAIry (setup.py)
filelock==3.10.0
filelock==3.11.0
# via
# diffusers
# huggingface-hub
# transformers
filterpy==1.4.5
# via facexlib
fonttools==4.39.2
fonttools==4.39.3
# via matplotlib
frozenlist==1.3.3
# via
# aiohttp
# aiosignal
fsspec[http]==2023.3.0
fsspec[http]==2023.4.0
# via pytorch-lightning
ftfy==6.1.1
# via
# imaginAIry (setup.py)
# open-clip-torch
huggingface-hub==0.13.2
huggingface-hub==0.13.4
# via
# diffusers
# open-clip-torch
@ -84,9 +82,9 @@ idna==3.4
# via
# requests
# yarl
imageio==2.26.0
imageio==2.27.0
# via imaginAIry (setup.py)
importlib-metadata==6.0.0
importlib-metadata==6.3.0
# via diffusers
iniconfig==2.0.0
# via pytest
@ -96,7 +94,7 @@ isort==5.12.0
# pylint
kiwisolver==1.4.4
# via matplotlib
kornia==0.6.10
kornia==0.6.11
# via imaginAIry (setup.py)
lazy-object-proxy==1.9.0
# via astroid
@ -147,7 +145,7 @@ opencv-python==4.7.0.72
# via
# facexlib
# imaginAIry (setup.py)
packaging==23.0
packaging==23.1
# via
# black
# huggingface-hub
@ -163,7 +161,7 @@ pathspec==0.10.3
# via
# black
# pycln
pillow==9.4.0
pillow==9.5.0
# via
# diffusers
# facexlib
@ -171,7 +169,7 @@ pillow==9.4.0
# imaginAIry (setup.py)
# matplotlib
# torchvision
platformdirs==3.1.1
platformdirs==3.2.0
# via
# black
# pylint
@ -193,22 +191,22 @@ pyflakes==3.0.1
# via pylama
pylama==8.4.1
# via -r requirements-dev.in
pylint==2.17.0
pylint==2.17.2
# via -r requirements-dev.in
pyparsing==3.0.9
# via matplotlib
pytest==7.2.2
pytest==7.3.0
# via
# -r requirements-dev.in
# pytest-randomly
# pytest-sugar
pytest-randomly==3.12.0
# via -r requirements-dev.in
pytest-sugar==0.9.6
pytest-sugar==0.9.7
# via -r requirements-dev.in
python-dateutil==2.8.2
# via matplotlib
pytorch-lightning==1.9.4
pytorch-lightning==1.9.5
# via imaginAIry (setup.py)
pyyaml==6.0
# via
@ -220,7 +218,7 @@ pyyaml==6.0
# responses
# timm
# transformers
regex==2022.10.31
regex==2023.3.23
# via
# diffusers
# open-clip-torch
@ -236,7 +234,7 @@ requests==2.28.2
# transformers
responses==0.23.1
# via -r requirements-dev.in
ruff==0.0.256
ruff==0.0.261
# via -r requirements-dev.in
safetensors==0.3.0
# via imaginAIry (setup.py)
@ -245,7 +243,7 @@ scipy==1.10.1
# facexlib
# filterpy
# torchdiffeq
sentencepiece==0.1.97
sentencepiece==0.1.98
# via open-clip-torch
six==1.16.0
# via python-dateutil
@ -253,18 +251,18 @@ snowballstemmer==2.2.0
# via pydocstyle
termcolor==2.2.0
# via pytest-sugar
timm==0.6.12
timm==0.6.13
# via
# imaginAIry (setup.py)
# open-clip-torch
tokenizers==0.13.2
tokenizers==0.13.3
# via transformers
tomli==2.0.1
# via
# black
# pylint
# pytest
tomlkit==0.11.6
tomlkit==0.11.7
# via
# pycln
# pylint
@ -300,11 +298,11 @@ tqdm==4.65.0
# open-clip-torch
# pytorch-lightning
# transformers
transformers==4.27.1
transformers==4.28.0
# via imaginAIry (setup.py)
typer==0.7.0
# via pycln
types-pyyaml==6.0.12.8
types-pyyaml==6.0.12.9
# via responses
typing-extensions==4.5.0
# via

View File

@ -11,6 +11,7 @@ from imaginairy.utils import (
get_device,
get_hardware_description,
get_obj_from_str,
glob_expand_paths,
instantiate_from_config,
)
@ -80,7 +81,50 @@ def test_instantiate_from_config():
instantiate_from_config(config)
#
# def test_platform_appropriate_autocast():
# with platform_appropriate_autocast("autocast"):
# pass
class TestGlobExpandPaths:
def test_valid_file_paths(self, tmp_path):
# create temporary file
file_path = tmp_path / "test.txt"
file_path.touch()
# test function with valid file path
result = glob_expand_paths([str(file_path)])
assert result == [str(file_path)]
def test_valid_http_urls(self):
# test function with valid http url
result = glob_expand_paths(["http://www.example.com"])
assert result == ["http://www.example.com"]
def test_file_paths_with_wildcards(self, tmp_path):
# create temporary files
file1 = tmp_path / "test1.txt"
file1.touch()
file2 = tmp_path / "test2.txt"
file2.touch()
# test function with file path containing wildcard
result = glob_expand_paths([str(tmp_path / "*.txt")])
result.sort()
assert result == [str(file1), str(file2)]
def test_empty_input(self):
# test function with empty input list
result = glob_expand_paths([])
assert not result
def test_nonexistent_file_paths(self):
# test function with non-existent file path
result = glob_expand_paths(["/nonexistent/path"])
assert not result
def test_user_expansion(self, monkeypatch, tmp_path):
file1 = tmp_path / "test1.txt"
file1.touch()
# monkeypatch os.path.expanduser to return a known path
monkeypatch.setattr("os.path.expanduser", lambda x: str(tmp_path / "test1.txt"))
# test function with user expansion
paths = ["~/file.txt"]
assert glob_expand_paths(paths) == [str(tmp_path / "test1.txt")]