mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
fix: pydantic models for http server working now. Fixes #380
This commit is contained in:
parent
ba51364a73
commit
8243ed616d
1
.gitignore
vendored
1
.gitignore
vendored
@ -30,3 +30,4 @@ tests/vastai_cli.py
|
|||||||
**/.eggs
|
**/.eggs
|
||||||
/img_size_memory_usage.csv
|
/img_size_memory_usage.csv
|
||||||
/tests/test_cluster_output/
|
/tests/test_cluster_output/
|
||||||
|
/.env
|
||||||
|
@ -493,6 +493,10 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
|
|||||||
|
|
||||||
## ChangeLog
|
## ChangeLog
|
||||||
|
|
||||||
|
**13.2.1**
|
||||||
|
- fix: pydantic models for http server working now. Fixes #380
|
||||||
|
- fix: install triton so annoying message is gone
|
||||||
|
|
||||||
**13.2.0**
|
**13.2.0**
|
||||||
- fix: allow tile_mode to be set to True or False for backward compatibility
|
- fix: allow tile_mode to be set to True or False for backward compatibility
|
||||||
- fix: various pydantic issues have been resolved
|
- fix: various pydantic issues have been resolved
|
||||||
|
@ -119,5 +119,5 @@ class StableStudioBatchResponse(BaseModel):
|
|||||||
images: List[StableStudioImage]
|
images: List[StableStudioImage]
|
||||||
|
|
||||||
|
|
||||||
StableStudioInput.update_forward_refs()
|
StableStudioInput.model_rebuild()
|
||||||
StableStudioImage.update_forward_refs()
|
StableStudioImage.model_rebuild()
|
||||||
|
@ -5,7 +5,7 @@ from imaginairy import imagine
|
|||||||
|
|
||||||
|
|
||||||
def generate_image(prompt):
|
def generate_image(prompt):
|
||||||
"""ImaginPrompt to generated image."""
|
"""ImaginePrompt to generated image."""
|
||||||
result = next(imagine([prompt]))
|
result = next(imagine([prompt]))
|
||||||
img = result.images["generated"]
|
img = result.images["generated"]
|
||||||
img_io = BytesIO()
|
img_io = BytesIO()
|
||||||
@ -27,7 +27,7 @@ class Base64Bytes(bytes):
|
|||||||
yield cls.validate
|
yield cls.validate
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, v):
|
def validate(cls, v, info):
|
||||||
if isinstance(v, bytes):
|
if isinstance(v, bytes):
|
||||||
return v
|
return v
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
|
@ -18,7 +18,6 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
from pydantic_core.core_schema import FieldValidationInfo
|
|
||||||
|
|
||||||
from imaginairy import config
|
from imaginairy import config
|
||||||
|
|
||||||
@ -90,6 +89,12 @@ class LazyLoadingImage:
|
|||||||
self._load_img()
|
self._load_img()
|
||||||
return getattr(self._img, key)
|
return getattr(self._img, key)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self.__dict__
|
||||||
|
|
||||||
def _load_img(self):
|
def _load_img(self):
|
||||||
if self._img is None:
|
if self._img is None:
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
@ -192,16 +197,10 @@ class LazyLoadingImage:
|
|||||||
|
|
||||||
shows filepath or url if available.
|
shows filepath or url if available.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"
|
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"
|
||||||
|
except Exception as e: # noqa
|
||||||
|
return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>"
|
||||||
#
|
|
||||||
# LazyLoadingImage = Annotated[
|
|
||||||
# _LazyLoadingImage,
|
|
||||||
# AfterValidator(_LazyLoadingImage.validate),
|
|
||||||
# PlainSerializer(lambda i: str(i), return_type=str),
|
|
||||||
# WithJsonSchema({"type": "string"}, mode="serialization"),
|
|
||||||
# ]
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetInput(BaseModel):
|
class ControlNetInput(BaseModel):
|
||||||
@ -218,7 +217,7 @@ class ControlNetInput(BaseModel):
|
|||||||
# return v
|
# return v
|
||||||
|
|
||||||
@field_validator("image_raw")
|
@field_validator("image_raw")
|
||||||
def image_raw_validate(cls, v, info: FieldValidationInfo):
|
def image_raw_validate(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
if info.data.get("image") is not None and v is not None:
|
if info.data.get("image") is not None and v is not None:
|
||||||
raise ValueError("You cannot specify both image and image_raw")
|
raise ValueError("You cannot specify both image and image_raw")
|
||||||
|
|
||||||
@ -245,7 +244,7 @@ class WeightedPrompt(BaseModel):
|
|||||||
return f"{self.weight}*({self.text})"
|
return f"{self.weight}*({self.text})"
|
||||||
|
|
||||||
|
|
||||||
class ImaginePrompt(BaseModel):
|
class ImaginePrompt(BaseModel, protected_namespaces=()):
|
||||||
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
|
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
|
||||||
negative_prompt: Optional[List[WeightedPrompt]] = Field(
|
negative_prompt: Optional[List[WeightedPrompt]] = Field(
|
||||||
default=None, validate_default=True
|
default=None, validate_default=True
|
||||||
@ -403,7 +402,7 @@ class ImaginePrompt(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("control_inputs", mode="after")
|
@field_validator("control_inputs", mode="after")
|
||||||
def set_image_from_init_image(cls, v, info: FieldValidationInfo):
|
def set_image_from_init_image(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
v = v or []
|
v = v or []
|
||||||
for control_input in v:
|
for control_input in v:
|
||||||
if control_input.image is None and control_input.image_raw is None:
|
if control_input.image is None and control_input.image_raw is None:
|
||||||
@ -411,13 +410,13 @@ class ImaginePrompt(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("mask_image")
|
@field_validator("mask_image")
|
||||||
def validate_mask_image(cls, v, info: FieldValidationInfo):
|
def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
if v is not None and info.data.get("mask_prompt") is not None:
|
if v is not None and info.data.get("mask_prompt") is not None:
|
||||||
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
|
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("mask_prompt", "mask_image", mode="before")
|
@field_validator("mask_prompt", "mask_image", mode="before")
|
||||||
def validate_mask_prompt(cls, v, info: FieldValidationInfo):
|
def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
if info.data.get("init_image") is None and v:
|
if info.data.get("init_image") is None and v:
|
||||||
raise ValueError("You must set `init_image` if you want to use a mask")
|
raise ValueError("You must set `init_image` if you want to use a mask")
|
||||||
return v
|
return v
|
||||||
@ -441,7 +440,7 @@ class ImaginePrompt(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("sampler_type", mode="after")
|
@field_validator("sampler_type", mode="after")
|
||||||
def validate_sampler_type(cls, v, info: FieldValidationInfo):
|
def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
from imaginairy.samplers import SamplerName
|
from imaginairy.samplers import SamplerName
|
||||||
|
|
||||||
if v is None:
|
if v is None:
|
||||||
@ -462,7 +461,7 @@ class ImaginePrompt(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("steps")
|
@field_validator("steps")
|
||||||
def validate_steps(cls, v, info: FieldValidationInfo):
|
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
from imaginairy.samplers import SAMPLER_LOOKUP
|
from imaginairy.samplers import SAMPLER_LOOKUP
|
||||||
|
|
||||||
if v is None:
|
if v is None:
|
||||||
@ -484,7 +483,7 @@ class ImaginePrompt(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("height", "width")
|
@field_validator("height", "width")
|
||||||
def validate_image_size(cls, v, info: FieldValidationInfo):
|
def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
|
||||||
from imaginairy.model_manager import get_model_default_image_size
|
from imaginairy.model_manager import get_model_default_image_size
|
||||||
|
|
||||||
if v is None:
|
if v is None:
|
||||||
|
@ -10,3 +10,5 @@ pytest-randomly
|
|||||||
pytest-sugar
|
pytest-sugar
|
||||||
responses
|
responses
|
||||||
wheel
|
wheel
|
||||||
|
|
||||||
|
-c tests/constraints.txt
|
@ -16,7 +16,7 @@ anyio==3.7.1
|
|||||||
# via
|
# via
|
||||||
# fastapi
|
# fastapi
|
||||||
# starlette
|
# starlette
|
||||||
astroid==2.15.6
|
astroid==2.15.8
|
||||||
# via pylint
|
# via pylint
|
||||||
async-timeout==4.0.3
|
async-timeout==4.0.3
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
@ -42,13 +42,13 @@ click-help-colors==0.9.2
|
|||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
click-shell==2.1
|
click-shell==2.1
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
contourpy==1.1.0
|
contourpy==1.1.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
coverage==7.3.1
|
coverage==7.3.1
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
cycler==0.11.0
|
cycler==0.12.0
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
diffusers==0.20.2
|
diffusers==0.21.3
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
dill==0.3.7
|
dill==0.3.7
|
||||||
# via pylint
|
# via pylint
|
||||||
@ -62,9 +62,9 @@ facexlib==0.3.0
|
|||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
fairscale==0.4.13
|
fairscale==0.4.13
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
fastapi==0.103.1
|
fastapi==0.103.2
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
filelock==3.12.3
|
filelock==3.12.4
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
@ -77,7 +77,7 @@ frozenlist==1.4.0
|
|||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# aiosignal
|
# aiosignal
|
||||||
fsspec[http]==2023.9.0
|
fsspec[http]==2023.9.2
|
||||||
# via
|
# via
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
@ -87,7 +87,7 @@ ftfy==6.1.1
|
|||||||
# open-clip-torch
|
# open-clip-torch
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
huggingface-hub==0.17.1
|
huggingface-hub==0.17.3
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# open-clip-torch
|
# open-clip-torch
|
||||||
@ -98,7 +98,7 @@ idna==3.4
|
|||||||
# anyio
|
# anyio
|
||||||
# requests
|
# requests
|
||||||
# yarl
|
# yarl
|
||||||
imageio==2.31.3
|
imageio==2.31.4
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
importlib-metadata==6.8.0
|
importlib-metadata==6.8.0
|
||||||
# via diffusers
|
# via diffusers
|
||||||
@ -120,10 +120,12 @@ lightning-utilities==0.9.0
|
|||||||
# via
|
# via
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
# torchmetrics
|
# torchmetrics
|
||||||
llvmlite==0.40.1
|
llvmlite==0.41.0
|
||||||
# via numba
|
# via numba
|
||||||
matplotlib==3.7.3
|
matplotlib==3.7.3
|
||||||
# via filterpy
|
# via
|
||||||
|
# -c tests/constraints.txt
|
||||||
|
# filterpy
|
||||||
mccabe==0.7.0
|
mccabe==0.7.0
|
||||||
# via
|
# via
|
||||||
# pylama
|
# pylama
|
||||||
@ -136,10 +138,11 @@ mypy-extensions==1.0.0
|
|||||||
# via
|
# via
|
||||||
# black
|
# black
|
||||||
# typing-inspect
|
# typing-inspect
|
||||||
numba==0.57.1
|
numba==0.58.0
|
||||||
# via facexlib
|
# via facexlib
|
||||||
numpy==1.24.4
|
numpy==1.24.4
|
||||||
# via
|
# via
|
||||||
|
# -c tests/constraints.txt
|
||||||
# contourpy
|
# contourpy
|
||||||
# diffusers
|
# diffusers
|
||||||
# facexlib
|
# facexlib
|
||||||
@ -159,7 +162,7 @@ omegaconf==2.3.0
|
|||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
open-clip-torch==2.20.0
|
open-clip-torch==2.20.0
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
opencv-python==4.8.0.76
|
opencv-python==4.8.1.78
|
||||||
# via
|
# via
|
||||||
# facexlib
|
# facexlib
|
||||||
# imaginAIry (setup.py)
|
# imaginAIry (setup.py)
|
||||||
@ -178,7 +181,7 @@ pathspec==0.11.2
|
|||||||
# via
|
# via
|
||||||
# black
|
# black
|
||||||
# pycln
|
# pycln
|
||||||
pillow==10.0.0
|
pillow==10.0.1
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# facexlib
|
# facexlib
|
||||||
@ -202,11 +205,11 @@ pycln==2.2.2
|
|||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
pycodestyle==2.11.0
|
pycodestyle==2.11.0
|
||||||
# via pylama
|
# via pylama
|
||||||
pydantic==2.3.0
|
pydantic==2.4.2
|
||||||
# via
|
# via
|
||||||
# fastapi
|
# fastapi
|
||||||
# imaginAIry (setup.py)
|
# imaginAIry (setup.py)
|
||||||
pydantic-core==2.6.3
|
pydantic-core==2.10.1
|
||||||
# via pydantic
|
# via pydantic
|
||||||
pydocstyle==6.3.0
|
pydocstyle==6.3.0
|
||||||
# via pylama
|
# via pylama
|
||||||
@ -214,7 +217,7 @@ pyflakes==3.1.0
|
|||||||
# via pylama
|
# via pylama
|
||||||
pylama==8.4.1
|
pylama==8.4.1
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
pylint==2.17.5
|
pylint==2.17.6
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
pyparsing==3.1.1
|
pyparsing==3.1.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
@ -257,7 +260,7 @@ requests==2.31.0
|
|||||||
# transformers
|
# transformers
|
||||||
responses==0.23.3
|
responses==0.23.3
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
ruff==0.0.288
|
ruff==0.0.291
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
safetensors==0.3.3
|
safetensors==0.3.3
|
||||||
# via
|
# via
|
||||||
@ -312,7 +315,7 @@ torch==1.13.1
|
|||||||
# torchvision
|
# torchvision
|
||||||
torchdiffeq==0.2.3
|
torchdiffeq==0.2.3
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
torchmetrics==1.1.2
|
torchmetrics==1.2.0
|
||||||
# via
|
# via
|
||||||
# imaginAIry (setup.py)
|
# imaginAIry (setup.py)
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
@ -330,18 +333,17 @@ tqdm==4.66.1
|
|||||||
# open-clip-torch
|
# open-clip-torch
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
# transformers
|
# transformers
|
||||||
transformers==4.33.1
|
transformers==4.33.3
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
typer==0.9.0
|
typer==0.9.0
|
||||||
# via pycln
|
# via pycln
|
||||||
types-pyyaml==6.0.12.11
|
types-pyyaml==6.0.12.12
|
||||||
# via responses
|
# via responses
|
||||||
typing-extensions==4.7.1
|
typing-extensions==4.8.0
|
||||||
# via
|
# via
|
||||||
# astroid
|
# astroid
|
||||||
# black
|
# black
|
||||||
# fastapi
|
# fastapi
|
||||||
# filelock
|
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# libcst
|
# libcst
|
||||||
# lightning-utilities
|
# lightning-utilities
|
||||||
@ -355,13 +357,13 @@ typing-extensions==4.7.1
|
|||||||
# uvicorn
|
# uvicorn
|
||||||
typing-inspect==0.9.0
|
typing-inspect==0.9.0
|
||||||
# via libcst
|
# via libcst
|
||||||
urllib3==2.0.4
|
urllib3==2.0.5
|
||||||
# via
|
# via
|
||||||
# requests
|
# requests
|
||||||
# responses
|
# responses
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
wcwidth==0.2.6
|
wcwidth==0.2.7
|
||||||
# via ftfy
|
# via ftfy
|
||||||
wheel==0.41.2
|
wheel==0.41.2
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
@ -369,5 +371,5 @@ wrapt==1.15.0
|
|||||||
# via astroid
|
# via astroid
|
||||||
yarl==1.9.2
|
yarl==1.9.2
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
zipp==3.16.2
|
zipp==3.17.0
|
||||||
# via importlib-metadata
|
# via importlib-metadata
|
||||||
|
3
setup.py
3
setup.py
@ -103,9 +103,10 @@ setup(
|
|||||||
"scipy<1.11",
|
"scipy<1.11",
|
||||||
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
|
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
|
||||||
"torchdiffeq>=0.2.0",
|
"torchdiffeq>=0.2.0",
|
||||||
"transformers>=4.19.2",
|
|
||||||
"torchmetrics>=0.6.0",
|
"torchmetrics>=0.6.0",
|
||||||
"torchvision>=0.13.1",
|
"torchvision>=0.13.1",
|
||||||
|
"transformers>=4.19.2",
|
||||||
|
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",
|
||||||
"kornia>=0.6",
|
"kornia>=0.6",
|
||||||
"uvicorn>=0.16.0",
|
"uvicorn>=0.16.0",
|
||||||
"xformers>=0.0.16; sys_platform!='darwin' and platform_machine!='aarch64'",
|
"xformers>=0.0.16; sys_platform!='darwin' and platform_machine!='aarch64'",
|
||||||
|
3
tests/constraints.txt
Normal file
3
tests/constraints.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# held back for python 3.8 compatability
|
||||||
|
matplotlib<3.8.0
|
||||||
|
numpy<1.25.0
|
@ -101,3 +101,14 @@ def test_image_deserialization(red_path, red_url):
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
obj = TestModel.model_validate(row)
|
obj = TestModel.model_validate(row)
|
||||||
assert obj.header_img.size == (512, 512)
|
assert obj.header_img.size == (512, 512)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_state(red_path):
|
||||||
|
"""I dont remember what this fixes. Maybe the ability of pydantic to copy an object?."""
|
||||||
|
img = LazyLoadingImage(filepath=red_path)
|
||||||
|
|
||||||
|
# bypass init
|
||||||
|
img2 = LazyLoadingImage.__new__(LazyLoadingImage)
|
||||||
|
img2.__setstate__(img.__getstate__())
|
||||||
|
|
||||||
|
assert repr(img) == repr(img2)
|
||||||
|
Loading…
Reference in New Issue
Block a user