mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
fix: better input image path handling
throw exception for non-existent images
This commit is contained in:
parent
3012d28357
commit
95d3d08d27
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
.idea
|
||||
.vscode
|
||||
.DS_Store
|
||||
__pycache__
|
||||
outputs/*
|
||||
|
@ -89,8 +89,14 @@ def _imagine_cmd(
|
||||
|
||||
from imaginairy.utils import glob_expand_paths
|
||||
|
||||
num_prexpaned_init_images = len(init_images)
|
||||
init_images = glob_expand_paths(init_images)
|
||||
|
||||
if len(init_images) < num_prexpaned_init_images:
|
||||
raise ValueError(
|
||||
f"Could not find any images matching the glob pattern(s) {init_image}. Are you sure the file(s) exists?"
|
||||
)
|
||||
|
||||
total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
|
||||
logger.info(
|
||||
f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images."
|
||||
|
@ -204,11 +204,12 @@ def shrink_list(items, max_size):
|
||||
|
||||
def glob_expand_paths(paths):
|
||||
import glob
|
||||
import os.path
|
||||
|
||||
expanded_paths = []
|
||||
for p in paths:
|
||||
if p.startswith("http"):
|
||||
expanded_paths.append(p)
|
||||
else:
|
||||
expanded_paths.extend(glob.glob(p))
|
||||
expanded_paths.extend(glob.glob(os.path.expanduser(p)))
|
||||
return expanded_paths
|
||||
|
@ -10,15 +10,13 @@ aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via omegaconf
|
||||
astroid==2.15.0
|
||||
astroid==2.15.2
|
||||
# via pylint
|
||||
async-timeout==4.0.2
|
||||
# via aiohttp
|
||||
attrs==22.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# pytest
|
||||
black==23.1.0
|
||||
# via aiohttp
|
||||
black==23.3.0
|
||||
# via -r requirements-dev.in
|
||||
certifi==2022.12.7
|
||||
# via requests
|
||||
@ -39,11 +37,11 @@ click-shell==2.1
|
||||
# via imaginAIry (setup.py)
|
||||
contourpy==1.0.7
|
||||
# via matplotlib
|
||||
coverage==7.2.2
|
||||
coverage==7.2.3
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.11.0
|
||||
# via matplotlib
|
||||
diffusers==0.14.0
|
||||
diffusers==0.15.0
|
||||
# via imaginAIry (setup.py)
|
||||
dill==0.3.6
|
||||
# via pylint
|
||||
@ -55,26 +53,26 @@ facexlib==0.2.5
|
||||
# via imaginAIry (setup.py)
|
||||
fairscale==0.4.13
|
||||
# via imaginAIry (setup.py)
|
||||
filelock==3.10.0
|
||||
filelock==3.11.0
|
||||
# via
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
filterpy==1.4.5
|
||||
# via facexlib
|
||||
fonttools==4.39.2
|
||||
fonttools==4.39.3
|
||||
# via matplotlib
|
||||
frozenlist==1.3.3
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2023.3.0
|
||||
fsspec[http]==2023.4.0
|
||||
# via pytorch-lightning
|
||||
ftfy==6.1.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
huggingface-hub==0.13.2
|
||||
huggingface-hub==0.13.4
|
||||
# via
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
@ -84,9 +82,9 @@ idna==3.4
|
||||
# via
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.26.0
|
||||
imageio==2.27.0
|
||||
# via imaginAIry (setup.py)
|
||||
importlib-metadata==6.0.0
|
||||
importlib-metadata==6.3.0
|
||||
# via diffusers
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
@ -96,7 +94,7 @@ isort==5.12.0
|
||||
# pylint
|
||||
kiwisolver==1.4.4
|
||||
# via matplotlib
|
||||
kornia==0.6.10
|
||||
kornia==0.6.11
|
||||
# via imaginAIry (setup.py)
|
||||
lazy-object-proxy==1.9.0
|
||||
# via astroid
|
||||
@ -147,7 +145,7 @@ opencv-python==4.7.0.72
|
||||
# via
|
||||
# facexlib
|
||||
# imaginAIry (setup.py)
|
||||
packaging==23.0
|
||||
packaging==23.1
|
||||
# via
|
||||
# black
|
||||
# huggingface-hub
|
||||
@ -163,7 +161,7 @@ pathspec==0.10.3
|
||||
# via
|
||||
# black
|
||||
# pycln
|
||||
pillow==9.4.0
|
||||
pillow==9.5.0
|
||||
# via
|
||||
# diffusers
|
||||
# facexlib
|
||||
@ -171,7 +169,7 @@ pillow==9.4.0
|
||||
# imaginAIry (setup.py)
|
||||
# matplotlib
|
||||
# torchvision
|
||||
platformdirs==3.1.1
|
||||
platformdirs==3.2.0
|
||||
# via
|
||||
# black
|
||||
# pylint
|
||||
@ -193,22 +191,22 @@ pyflakes==3.0.1
|
||||
# via pylama
|
||||
pylama==8.4.1
|
||||
# via -r requirements-dev.in
|
||||
pylint==2.17.0
|
||||
pylint==2.17.2
|
||||
# via -r requirements-dev.in
|
||||
pyparsing==3.0.9
|
||||
# via matplotlib
|
||||
pytest==7.2.2
|
||||
pytest==7.3.0
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pytest-randomly
|
||||
# pytest-sugar
|
||||
pytest-randomly==3.12.0
|
||||
# via -r requirements-dev.in
|
||||
pytest-sugar==0.9.6
|
||||
pytest-sugar==0.9.7
|
||||
# via -r requirements-dev.in
|
||||
python-dateutil==2.8.2
|
||||
# via matplotlib
|
||||
pytorch-lightning==1.9.4
|
||||
pytorch-lightning==1.9.5
|
||||
# via imaginAIry (setup.py)
|
||||
pyyaml==6.0
|
||||
# via
|
||||
@ -220,7 +218,7 @@ pyyaml==6.0
|
||||
# responses
|
||||
# timm
|
||||
# transformers
|
||||
regex==2022.10.31
|
||||
regex==2023.3.23
|
||||
# via
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
@ -236,7 +234,7 @@ requests==2.28.2
|
||||
# transformers
|
||||
responses==0.23.1
|
||||
# via -r requirements-dev.in
|
||||
ruff==0.0.256
|
||||
ruff==0.0.261
|
||||
# via -r requirements-dev.in
|
||||
safetensors==0.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
@ -245,7 +243,7 @@ scipy==1.10.1
|
||||
# facexlib
|
||||
# filterpy
|
||||
# torchdiffeq
|
||||
sentencepiece==0.1.97
|
||||
sentencepiece==0.1.98
|
||||
# via open-clip-torch
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
@ -253,18 +251,18 @@ snowballstemmer==2.2.0
|
||||
# via pydocstyle
|
||||
termcolor==2.2.0
|
||||
# via pytest-sugar
|
||||
timm==0.6.12
|
||||
timm==0.6.13
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
tokenizers==0.13.2
|
||||
tokenizers==0.13.3
|
||||
# via transformers
|
||||
tomli==2.0.1
|
||||
# via
|
||||
# black
|
||||
# pylint
|
||||
# pytest
|
||||
tomlkit==0.11.6
|
||||
tomlkit==0.11.7
|
||||
# via
|
||||
# pycln
|
||||
# pylint
|
||||
@ -300,11 +298,11 @@ tqdm==4.65.0
|
||||
# open-clip-torch
|
||||
# pytorch-lightning
|
||||
# transformers
|
||||
transformers==4.27.1
|
||||
transformers==4.28.0
|
||||
# via imaginAIry (setup.py)
|
||||
typer==0.7.0
|
||||
# via pycln
|
||||
types-pyyaml==6.0.12.8
|
||||
types-pyyaml==6.0.12.9
|
||||
# via responses
|
||||
typing-extensions==4.5.0
|
||||
# via
|
||||
|
@ -11,6 +11,7 @@ from imaginairy.utils import (
|
||||
get_device,
|
||||
get_hardware_description,
|
||||
get_obj_from_str,
|
||||
glob_expand_paths,
|
||||
instantiate_from_config,
|
||||
)
|
||||
|
||||
@ -80,7 +81,50 @@ def test_instantiate_from_config():
|
||||
instantiate_from_config(config)
|
||||
|
||||
|
||||
#
|
||||
# def test_platform_appropriate_autocast():
|
||||
# with platform_appropriate_autocast("autocast"):
|
||||
# pass
|
||||
class TestGlobExpandPaths:
|
||||
def test_valid_file_paths(self, tmp_path):
|
||||
# create temporary file
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.touch()
|
||||
|
||||
# test function with valid file path
|
||||
result = glob_expand_paths([str(file_path)])
|
||||
assert result == [str(file_path)]
|
||||
|
||||
def test_valid_http_urls(self):
|
||||
# test function with valid http url
|
||||
result = glob_expand_paths(["http://www.example.com"])
|
||||
assert result == ["http://www.example.com"]
|
||||
|
||||
def test_file_paths_with_wildcards(self, tmp_path):
|
||||
# create temporary files
|
||||
file1 = tmp_path / "test1.txt"
|
||||
file1.touch()
|
||||
file2 = tmp_path / "test2.txt"
|
||||
file2.touch()
|
||||
|
||||
# test function with file path containing wildcard
|
||||
result = glob_expand_paths([str(tmp_path / "*.txt")])
|
||||
result.sort()
|
||||
assert result == [str(file1), str(file2)]
|
||||
|
||||
def test_empty_input(self):
|
||||
# test function with empty input list
|
||||
result = glob_expand_paths([])
|
||||
assert not result
|
||||
|
||||
def test_nonexistent_file_paths(self):
|
||||
# test function with non-existent file path
|
||||
result = glob_expand_paths(["/nonexistent/path"])
|
||||
assert not result
|
||||
|
||||
def test_user_expansion(self, monkeypatch, tmp_path):
|
||||
file1 = tmp_path / "test1.txt"
|
||||
file1.touch()
|
||||
|
||||
# monkeypatch os.path.expanduser to return a known path
|
||||
monkeypatch.setattr("os.path.expanduser", lambda x: str(tmp_path / "test1.txt"))
|
||||
|
||||
# test function with user expansion
|
||||
paths = ["~/file.txt"]
|
||||
assert glob_expand_paths(paths) == [str(tmp_path / "test1.txt")]
|
||||
|
Loading…
Reference in New Issue
Block a user