mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
build: update requirements
This commit is contained in:
parent
26c0a5608b
commit
a512ed7032
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
@ -36,7 +36,7 @@ jobs:
|
||||
- name: Lint
|
||||
run: |
|
||||
echo "::add-matcher::.github/pylama_matcher.json"
|
||||
ruff --config tests/ruff.toml .
|
||||
ruff check --config tests/ruff.toml .
|
||||
test-gpu:
|
||||
runs-on: nvidia-4090
|
||||
steps:
|
||||
|
@ -51,16 +51,16 @@ def mod_get_invoke(command):
|
||||
return False
|
||||
|
||||
invoke_ = update_wrapper(invoke_, command.callback)
|
||||
invoke_.__name__ = "do_%s" % command.name
|
||||
invoke_.__name__ = f"do_{command.name}"
|
||||
return invoke_
|
||||
|
||||
|
||||
class ModClickShell(ClickShell):
|
||||
def add_command(self, cmd, name):
|
||||
# Use the MethodType to add these as bound methods to our current instance
|
||||
setattr(self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self))
|
||||
setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self))
|
||||
setattr(self, "complete_%s" % name, get_method_type(get_complete(cmd), self))
|
||||
setattr(self, f"do_{name}", get_method_type(mod_get_invoke(cmd), self))
|
||||
setattr(self, f"help_{name}", get_method_type(get_help(cmd), self))
|
||||
setattr(self, f"complete_{name}", get_method_type(get_complete(cmd), self))
|
||||
|
||||
|
||||
class ModShell(Shell):
|
||||
|
@ -310,12 +310,12 @@ def interpolate_video_file(
|
||||
if montage:
|
||||
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
for mid in output:
|
||||
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
|
||||
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0) # type: ignore
|
||||
write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
|
||||
else:
|
||||
write_buffer.put(lastframe)
|
||||
for mid in output:
|
||||
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
|
||||
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0) # type: ignore
|
||||
write_buffer.put(mid[:h, :w])
|
||||
pbar.update(1)
|
||||
lastframe = frame
|
||||
|
@ -15,7 +15,7 @@ def create_canny_edges(img: "Tensor") -> "Tensor":
|
||||
|
||||
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
img = einops.rearrange(img[0], "c h w -> h w c")
|
||||
img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze()
|
||||
img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze() # type: ignore
|
||||
blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8) # type: ignore
|
||||
|
||||
if len(blurred.shape) > 2:
|
||||
|
@ -363,7 +363,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
if attn_type == "vanilla-xformers":
|
||||
# print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
if type == "memory-efficient-cross-attn":
|
||||
if attn_type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
if attn_type == "none":
|
||||
|
@ -32,7 +32,7 @@ class NLayerDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
norm_layer = nn.BatchNorm2d if not use_actnorm else ActNorm
|
||||
if (
|
||||
type(norm_layer) == functools.partial
|
||||
type(norm_layer) == functools.partial # noqa
|
||||
): # no need to use bias as BatchNorm2d has affine parameters
|
||||
use_bias = norm_layer.func != nn.BatchNorm2d
|
||||
else:
|
||||
|
@ -308,7 +308,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
# f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
|
||||
# )
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
elif attn_type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
elif attn_type == "none":
|
||||
|
@ -67,7 +67,7 @@ def make_bounce_animation(
|
||||
def _ensure_pillow_images(
|
||||
imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]",
|
||||
) -> "List[Image.Image]":
|
||||
converted_frames: "List[Image.Image]" = []
|
||||
converted_frames: List[Image.Image] = []
|
||||
for frame in imgs:
|
||||
if isinstance(frame, torch.Tensor):
|
||||
converted_frames.append(model_latents_to_pillow_imgs(frame)[0])
|
||||
|
@ -72,7 +72,7 @@ class WeightMap:
|
||||
return source_keys.issubset(self.all_valid_prefixes)
|
||||
|
||||
def cast_weights(self, source_weights) -> dict[str, "Tensor"]:
|
||||
converted_state_dict: dict[str, "Tensor"] = {}
|
||||
converted_state_dict: dict[str, Tensor] = {}
|
||||
for source_key in source_weights:
|
||||
try:
|
||||
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
|
||||
|
@ -4,19 +4,21 @@
|
||||
#
|
||||
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
|
||||
#
|
||||
annotated-types==0.6.0
|
||||
accelerate==0.34.2
|
||||
# via imaginAIry (setup.py)
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via omegaconf
|
||||
anyio==4.3.0
|
||||
anyio==4.6.0
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
babel==2.14.0
|
||||
babel==2.16.0
|
||||
# via mkdocs-material
|
||||
build==1.2.1
|
||||
build==1.2.2
|
||||
# via pip-tools
|
||||
certifi==2024.2.2
|
||||
certifi==2024.8.30
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
@ -41,37 +43,37 @@ colorama==0.4.6
|
||||
# via
|
||||
# griffe
|
||||
# mkdocs-material
|
||||
coverage==7.4.4
|
||||
coverage==7.6.1
|
||||
# via -r requirements-dev.in
|
||||
diffusers==0.27.2
|
||||
diffusers==0.30.3
|
||||
# via imaginAIry (setup.py)
|
||||
einops==0.7.0
|
||||
einops==0.8.0
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# spandrel
|
||||
exceptiongroup==1.2.0
|
||||
exceptiongroup==1.2.2
|
||||
# via
|
||||
# anyio
|
||||
# pytest
|
||||
fastapi==0.110.1
|
||||
fastapi==0.115.0
|
||||
# via imaginAIry (setup.py)
|
||||
filelock==3.13.4
|
||||
filelock==3.16.1
|
||||
# via
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
fsspec==2024.3.1
|
||||
fsspec==2024.9.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
ftfy==6.2.0
|
||||
ftfy==6.2.3
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
ghp-import==2.1.0
|
||||
# via mkdocs
|
||||
griffe==0.42.2
|
||||
griffe==1.3.1
|
||||
# via mkdocstrings-python
|
||||
h11==0.14.0
|
||||
# via
|
||||
@ -79,37 +81,38 @@ h11==0.14.0
|
||||
# uvicorn
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httpx==0.27.0
|
||||
httpx==0.27.2
|
||||
# via -r requirements-dev.in
|
||||
huggingface-hub==0.22.2
|
||||
huggingface-hub==0.25.0
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
idna==3.7
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
importlib-metadata==7.1.0
|
||||
importlib-metadata==8.5.0
|
||||
# via diffusers
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
jaxtyping==0.2.28
|
||||
jaxtyping==0.2.34
|
||||
# via imaginAIry (setup.py)
|
||||
jinja2==3.1.3
|
||||
jinja2==3.1.4
|
||||
# via
|
||||
# mkdocs
|
||||
# mkdocs-material
|
||||
# mkdocstrings
|
||||
# torch
|
||||
kornia==0.7.2
|
||||
kornia==0.7.3
|
||||
# via imaginAIry (setup.py)
|
||||
kornia-rs==0.1.3
|
||||
kornia-rs==0.1.5
|
||||
# via kornia
|
||||
markdown==3.6
|
||||
markdown==3.7
|
||||
# via
|
||||
# mkdocs
|
||||
# mkdocs-autorefs
|
||||
@ -124,53 +127,65 @@ markupsafe==2.1.5
|
||||
# mkdocs-autorefs
|
||||
# mkdocstrings
|
||||
mergedeep==1.3.4
|
||||
# via mkdocs
|
||||
mkdocs==1.5.3
|
||||
# via
|
||||
# mkdocs
|
||||
# mkdocs-get-deps
|
||||
mkdocs==1.6.1
|
||||
# via
|
||||
# mkdocs-autorefs
|
||||
# mkdocs-material
|
||||
# mkdocstrings
|
||||
mkdocs-autorefs==1.0.1
|
||||
# via mkdocstrings
|
||||
mkdocs-autorefs==1.2.0
|
||||
# via
|
||||
# mkdocstrings
|
||||
# mkdocstrings-python
|
||||
mkdocs-click==0.8.1
|
||||
# via -r requirements-dev.in
|
||||
mkdocs-material==9.5.18
|
||||
mkdocs-get-deps==0.2.0
|
||||
# via mkdocs
|
||||
mkdocs-material==9.5.36
|
||||
# via -r requirements-dev.in
|
||||
mkdocs-material-extensions==1.3.1
|
||||
# via mkdocs-material
|
||||
mkdocstrings[python]==0.24.3
|
||||
mkdocstrings[python]==0.26.1
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# mkdocstrings-python
|
||||
mkdocstrings-python==1.9.2
|
||||
mkdocstrings-python==1.11.1
|
||||
# via mkdocstrings
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mypy==1.9.0
|
||||
mypy==1.11.2
|
||||
# via -r requirements-dev.in
|
||||
mypy-extensions==1.0.0
|
||||
# via mypy
|
||||
networkx==3.3
|
||||
# via torch
|
||||
ninja==1.11.1.1
|
||||
# via optimum-quanto
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -c tests/constraints.txt
|
||||
# accelerate
|
||||
# diffusers
|
||||
# imaginAIry (setup.py)
|
||||
# jaxtyping
|
||||
# opencv-python
|
||||
# optimum-quanto
|
||||
# scipy
|
||||
# spandrel
|
||||
# torchvision
|
||||
# transformers
|
||||
omegaconf==2.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
open-clip-torch==2.24.0
|
||||
open-clip-torch==2.26.1
|
||||
# via imaginAIry (setup.py)
|
||||
opencv-python==4.9.0.80
|
||||
opencv-python==4.10.0.84
|
||||
# via imaginAIry (setup.py)
|
||||
packaging==24.0
|
||||
optimum-quanto==0.2.4
|
||||
# via imaginAIry (setup.py)
|
||||
packaging==24.1
|
||||
# via
|
||||
# accelerate
|
||||
# build
|
||||
# huggingface-hub
|
||||
# kornia
|
||||
@ -178,54 +193,54 @@ packaging==24.0
|
||||
# pytest
|
||||
# pytest-sugar
|
||||
# transformers
|
||||
paginate==0.5.6
|
||||
paginate==0.5.7
|
||||
# via mkdocs-material
|
||||
pathspec==0.12.1
|
||||
# via mkdocs
|
||||
pillow==10.3.0
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# diffusers
|
||||
# imaginAIry (setup.py)
|
||||
# torchvision
|
||||
pip-tools==7.4.1
|
||||
# via -r requirements-dev.in
|
||||
platformdirs==4.2.0
|
||||
platformdirs==4.3.6
|
||||
# via
|
||||
# mkdocs
|
||||
# mkdocs-get-deps
|
||||
# mkdocstrings
|
||||
pluggy==1.4.0
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
protobuf==5.26.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
psutil==5.9.8
|
||||
protobuf==5.28.2
|
||||
# via imaginAIry (setup.py)
|
||||
pydantic==2.7.0
|
||||
psutil==6.0.0
|
||||
# via
|
||||
# accelerate
|
||||
# imaginAIry (setup.py)
|
||||
pydantic==2.9.2
|
||||
# via
|
||||
# fastapi
|
||||
# imaginAIry (setup.py)
|
||||
pydantic-core==2.18.1
|
||||
pydantic-core==2.23.4
|
||||
# via pydantic
|
||||
pygments==2.17.2
|
||||
pygments==2.18.0
|
||||
# via mkdocs-material
|
||||
pymdown-extensions==10.7.1
|
||||
pymdown-extensions==10.9
|
||||
# via
|
||||
# mkdocs-material
|
||||
# mkdocstrings
|
||||
pyparsing==3.1.2
|
||||
pyparsing==3.1.4
|
||||
# via imaginAIry (setup.py)
|
||||
pyproject-hooks==1.0.0
|
||||
pyproject-hooks==1.1.0
|
||||
# via
|
||||
# build
|
||||
# pip-tools
|
||||
pytest==8.1.1
|
||||
pytest==8.3.3
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pytest-asyncio
|
||||
# pytest-randomly
|
||||
# pytest-sugar
|
||||
pytest-asyncio==0.23.6
|
||||
pytest-asyncio==0.24.0
|
||||
# via -r requirements-dev.in
|
||||
pytest-randomly==3.15.0
|
||||
# via -r requirements-dev.in
|
||||
@ -233,10 +248,12 @@ pytest-sugar==1.0.0
|
||||
# via -r requirements-dev.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via ghp-import
|
||||
pyyaml==6.0.1
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# huggingface-hub
|
||||
# mkdocs
|
||||
# mkdocs-get-deps
|
||||
# omegaconf
|
||||
# pymdown-extensions
|
||||
# pyyaml-env-tag
|
||||
@ -245,13 +262,13 @@ pyyaml==6.0.1
|
||||
# transformers
|
||||
pyyaml-env-tag==0.1
|
||||
# via mkdocs
|
||||
regex==2024.4.16
|
||||
regex==2024.9.11
|
||||
# via
|
||||
# diffusers
|
||||
# mkdocs-material
|
||||
# open-clip-torch
|
||||
# transformers
|
||||
requests==2.31.0
|
||||
requests==2.32.3
|
||||
# via
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
@ -259,88 +276,91 @@ requests==2.31.0
|
||||
# mkdocs-material
|
||||
# responses
|
||||
# transformers
|
||||
responses==0.25.0
|
||||
responses==0.25.3
|
||||
# via -r requirements-dev.in
|
||||
ruff==0.3.7
|
||||
ruff==0.6.7
|
||||
# via -r requirements-dev.in
|
||||
safetensors==0.4.3
|
||||
safetensors==0.4.5
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# imaginAIry (setup.py)
|
||||
# optimum-quanto
|
||||
# spandrel
|
||||
# timm
|
||||
# transformers
|
||||
scipy==1.13.0
|
||||
scipy==1.14.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# torchdiffeq
|
||||
sentencepiece==0.2.0
|
||||
# via open-clip-torch
|
||||
# via imaginAIry (setup.py)
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
spandrel==0.3.1
|
||||
spandrel==0.4.0
|
||||
# via imaginAIry (setup.py)
|
||||
starlette==0.37.2
|
||||
starlette==0.38.5
|
||||
# via fastapi
|
||||
sympy==1.12
|
||||
sympy==1.13.3
|
||||
# via torch
|
||||
termcolor==2.4.0
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# pytest-sugar
|
||||
timm==0.9.16
|
||||
timm==1.0.9
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
tokenizers==0.15.2
|
||||
tokenizers==0.19.1
|
||||
# via transformers
|
||||
tomli==2.0.1
|
||||
# via
|
||||
# build
|
||||
# mypy
|
||||
# pip-tools
|
||||
# pyproject-hooks
|
||||
# pytest
|
||||
torch==2.2.2
|
||||
torch==2.4.1
|
||||
# via
|
||||
# accelerate
|
||||
# imaginAIry (setup.py)
|
||||
# kornia
|
||||
# open-clip-torch
|
||||
# optimum-quanto
|
||||
# spandrel
|
||||
# timm
|
||||
# torchdiffeq
|
||||
# torchvision
|
||||
torchdiffeq==0.2.3
|
||||
torchdiffeq==0.2.4
|
||||
# via imaginAIry (setup.py)
|
||||
torchvision==0.17.2
|
||||
torchvision==0.19.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
# spandrel
|
||||
# timm
|
||||
tqdm==4.66.2
|
||||
tqdm==4.66.5
|
||||
# via
|
||||
# huggingface-hub
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
# transformers
|
||||
transformers==4.39.3
|
||||
transformers==4.44.2
|
||||
# via imaginAIry (setup.py)
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
types-pillow==10.2.0.20240415
|
||||
types-pillow==10.2.0.20240822
|
||||
# via -r requirements-dev.in
|
||||
types-psutil==5.9.5.20240316
|
||||
types-psutil==6.0.0.20240901
|
||||
# via -r requirements-dev.in
|
||||
types-requests==2.31.0.20240406
|
||||
types-requests==2.32.0.20240914
|
||||
# via -r requirements-dev.in
|
||||
types-tqdm==4.66.0.20240417
|
||||
# via -r requirements-dev.in
|
||||
typing-extensions==4.11.0
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# fastapi
|
||||
@ -351,22 +371,22 @@ typing-extensions==4.11.0
|
||||
# spandrel
|
||||
# torch
|
||||
# uvicorn
|
||||
urllib3==2.2.1
|
||||
urllib3==2.2.3
|
||||
# via
|
||||
# requests
|
||||
# responses
|
||||
# types-requests
|
||||
uvicorn==0.29.0
|
||||
uvicorn==0.30.6
|
||||
# via imaginAIry (setup.py)
|
||||
watchdog==4.0.0
|
||||
watchdog==5.0.2
|
||||
# via mkdocs
|
||||
wcwidth==0.2.13
|
||||
# via ftfy
|
||||
wheel==0.43.0
|
||||
wheel==0.44.0
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pip-tools
|
||||
zipp==3.18.1
|
||||
zipp==3.20.2
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
|
5
setup.py
5
setup.py
@ -84,7 +84,7 @@ setup(
|
||||
# https://numpy.org/neps/nep-0029-deprecation_policy.html
|
||||
"numpy>=1.22.0",
|
||||
"tqdm>=4.64.0",
|
||||
"diffusers>=0.3.0",
|
||||
"diffusers>=0.30.3",
|
||||
"Pillow>=9.1.0",
|
||||
"psutil>5.7.3",
|
||||
"omegaconf>=2.1.1",
|
||||
@ -110,6 +110,9 @@ setup(
|
||||
"uvicorn>=0.16.0",
|
||||
"spandrel>=0.1.8",
|
||||
# "xformers>=0.0.22; sys_platform!='darwin' and platform_machine!='aarch64'",
|
||||
"optimum-quanto>=0.2.4", # for flux quantization
|
||||
"sentencepiece>=0.2.0",
|
||||
"accelerate>=0.24.0",
|
||||
],
|
||||
# don't specify maximum python versions as it can cause very long dependency resolution issues as the resolver
|
||||
# goes back to older versions of packages that didn't specify a maximum
|
||||
|
@ -85,13 +85,13 @@ def _reset_get_device():
|
||||
get_device.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def filename_base_for_outputs(request):
|
||||
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_"
|
||||
return filename_base
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def filename_base_for_orig_outputs(request):
|
||||
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.originalname}_"
|
||||
return filename_base
|
||||
@ -102,7 +102,7 @@ def solver_type(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def mocked_responses():
|
||||
with responses.RequestsMock() as rsps:
|
||||
yield rsps
|
||||
|
@ -21,7 +21,7 @@ def control_img_to_pillow_img(img_t):
|
||||
control_mode_params = list(CONTROL_MODES.items())
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.skip
|
||||
def test_compare_depth_maps(filename_base_for_outputs):
|
||||
sizes = [384, 512, 768]
|
||||
model_types = ISL_PATHS
|
||||
|
@ -10,7 +10,7 @@ from imaginairy.cli.upscale import (
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def mock_pil_save():
|
||||
with patch.object(Image, "save", autospec=True) as mock_save:
|
||||
yield mock_save
|
||||
|
@ -8,13 +8,13 @@ from imaginairy.http_app.app import app
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def mock_generate_image(monkeypatch):
|
||||
fake_generate = mock.MagicMock(return_value=iter("a fake image"))
|
||||
monkeypatch.setattr("imaginairy.http_app.app.generate_image", fake_generate)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.asyncio
|
||||
def test_imagine_endpoint(mock_generate_image):
|
||||
test_input = {"prompt": "test prompt"}
|
||||
|
||||
@ -24,7 +24,7 @@ def test_imagine_endpoint(mock_generate_image):
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_imagine_endpoint(mock_generate_image):
|
||||
test_input = {"text": "a dog"}
|
||||
|
||||
@ -34,7 +34,7 @@ async def test_get_imagine_endpoint(mock_generate_image):
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_imagine_endpoint_mp(mock_generate_image):
|
||||
test_input = {"text": "a dog"}
|
||||
|
||||
|
@ -13,7 +13,7 @@ def _red_b64():
|
||||
return b"iVBORw0KGgoAAAANSUhEUgAAAgAAAAIAAQMAAADOtka5AAAABlBMVEX/AAD///9BHTQRAAAANklEQVR4nO3BAQEAAACCIP+vbkhAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8G4IAAAHSeInwAAAAAElFTkSuQmCC"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def mock_generate_image_b64(monkeypatch, red_b64):
|
||||
fake_generate = mock.MagicMock(return_value=red_b64)
|
||||
monkeypatch.setattr(
|
||||
@ -21,7 +21,7 @@ def mock_generate_image_b64(monkeypatch, red_b64):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_endpoint(mock_generate_image_b64, red_b64):
|
||||
test_input = {
|
||||
"input": {
|
||||
@ -41,7 +41,7 @@ async def test_generate_endpoint(mock_generate_image_b64, red_b64):
|
||||
assert image["blob"] == red_b64.decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_samplers():
|
||||
response = client.get("/api/stablestudio/samplers")
|
||||
assert response.status_code == 200
|
||||
@ -51,7 +51,7 @@ async def test_list_samplers():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_models():
|
||||
response = client.get("/api/stablestudio/models")
|
||||
assert response.status_code == 200
|
||||
|
@ -64,7 +64,7 @@ def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat):
|
||||
diff_img.save(f"{filename_base_for_outputs}_diff.png")
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.skip
|
||||
def test_encode_decode_naive_scale(filename_base_for_outputs):
|
||||
model = get_diffusion_model()
|
||||
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/dog.jpg")
|
||||
|
@ -29,7 +29,7 @@ class MockedMemory:
|
||||
cls.peak_memory = cls.allocated_memory
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def mocked_memory(monkeypatch):
|
||||
monkeypatch.setattr(TorchRAMTracker, "mem_interface", MockedMemory)
|
||||
MockedMemory.allocated_memory = 0
|
||||
|
@ -23,7 +23,7 @@ def create_model_of_n_bytes(n):
|
||||
return DummyMemoryModule(n)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.parametrize(
|
||||
"model_version",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user