fix: pydantic models for http server working now. Fixes #380

This commit is contained in:
Bryce 2023-09-28 23:32:30 -07:00 committed by Bryce Drennan
parent ba51364a73
commit 8243ed616d
10 changed files with 73 additions and 50 deletions

1
.gitignore vendored
View File

@ -30,3 +30,4 @@ tests/vastai_cli.py
**/.eggs **/.eggs
/img_size_memory_usage.csv /img_size_memory_usage.csv
/tests/test_cluster_output/ /tests/test_cluster_output/
/.env

View File

@ -493,6 +493,10 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
## ChangeLog ## ChangeLog
**13.2.1**
- fix: pydantic models for http server working now. Fixes #380
- fix: install triton so annoying message is gone
**13.2.0** **13.2.0**
- fix: allow tile_mode to be set to True or False for backward compatibility - fix: allow tile_mode to be set to True or False for backward compatibility
- fix: various pydantic issues have been resolved - fix: various pydantic issues have been resolved

View File

@ -119,5 +119,5 @@ class StableStudioBatchResponse(BaseModel):
images: List[StableStudioImage] images: List[StableStudioImage]
StableStudioInput.update_forward_refs() StableStudioInput.model_rebuild()
StableStudioImage.update_forward_refs() StableStudioImage.model_rebuild()

View File

@ -5,7 +5,7 @@ from imaginairy import imagine
def generate_image(prompt): def generate_image(prompt):
"""ImaginPrompt to generated image.""" """ImaginePrompt to generated image."""
result = next(imagine([prompt])) result = next(imagine([prompt]))
img = result.images["generated"] img = result.images["generated"]
img_io = BytesIO() img_io = BytesIO()
@ -27,7 +27,7 @@ class Base64Bytes(bytes):
yield cls.validate yield cls.validate
@classmethod @classmethod
def validate(cls, v): def validate(cls, v, info):
if isinstance(v, bytes): if isinstance(v, bytes):
return v return v
if isinstance(v, str): if isinstance(v, str):

View File

@ -18,7 +18,6 @@ from pydantic import (
model_validator, model_validator,
) )
from pydantic_core import core_schema from pydantic_core import core_schema
from pydantic_core.core_schema import FieldValidationInfo
from imaginairy import config from imaginairy import config
@ -90,6 +89,12 @@ class LazyLoadingImage:
self._load_img() self._load_img()
return getattr(self._img, key) return getattr(self._img, key)
def __setstate__(self, state):
self.__dict__.update(state)
def __getstate__(self):
return self.__dict__
def _load_img(self): def _load_img(self):
if self._img is None: if self._img is None:
from PIL import Image, ImageOps from PIL import Image, ImageOps
@ -192,16 +197,10 @@ class LazyLoadingImage:
shows filepath or url if available. shows filepath or url if available.
""" """
try:
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>" return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"
except Exception as e: # noqa
return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>"
#
# LazyLoadingImage = Annotated[
# _LazyLoadingImage,
# AfterValidator(_LazyLoadingImage.validate),
# PlainSerializer(lambda i: str(i), return_type=str),
# WithJsonSchema({"type": "string"}, mode="serialization"),
# ]
class ControlNetInput(BaseModel): class ControlNetInput(BaseModel):
@ -218,7 +217,7 @@ class ControlNetInput(BaseModel):
# return v # return v
@field_validator("image_raw") @field_validator("image_raw")
def image_raw_validate(cls, v, info: FieldValidationInfo): def image_raw_validate(cls, v, info: core_schema.FieldValidationInfo):
if info.data.get("image") is not None and v is not None: if info.data.get("image") is not None and v is not None:
raise ValueError("You cannot specify both image and image_raw") raise ValueError("You cannot specify both image and image_raw")
@ -245,7 +244,7 @@ class WeightedPrompt(BaseModel):
return f"{self.weight}*({self.text})" return f"{self.weight}*({self.text})"
class ImaginePrompt(BaseModel): class ImaginePrompt(BaseModel, protected_namespaces=()):
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True) prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
negative_prompt: Optional[List[WeightedPrompt]] = Field( negative_prompt: Optional[List[WeightedPrompt]] = Field(
default=None, validate_default=True default=None, validate_default=True
@ -403,7 +402,7 @@ class ImaginePrompt(BaseModel):
return v return v
@field_validator("control_inputs", mode="after") @field_validator("control_inputs", mode="after")
def set_image_from_init_image(cls, v, info: FieldValidationInfo): def set_image_from_init_image(cls, v, info: core_schema.FieldValidationInfo):
v = v or [] v = v or []
for control_input in v: for control_input in v:
if control_input.image is None and control_input.image_raw is None: if control_input.image is None and control_input.image_raw is None:
@ -411,13 +410,13 @@ class ImaginePrompt(BaseModel):
return v return v
@field_validator("mask_image") @field_validator("mask_image")
def validate_mask_image(cls, v, info: FieldValidationInfo): def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
if v is not None and info.data.get("mask_prompt") is not None: if v is not None and info.data.get("mask_prompt") is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`") raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
return v return v
@field_validator("mask_prompt", "mask_image", mode="before") @field_validator("mask_prompt", "mask_image", mode="before")
def validate_mask_prompt(cls, v, info: FieldValidationInfo): def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
if info.data.get("init_image") is None and v: if info.data.get("init_image") is None and v:
raise ValueError("You must set `init_image` if you want to use a mask") raise ValueError("You must set `init_image` if you want to use a mask")
return v return v
@ -441,7 +440,7 @@ class ImaginePrompt(BaseModel):
return v return v
@field_validator("sampler_type", mode="after") @field_validator("sampler_type", mode="after")
def validate_sampler_type(cls, v, info: FieldValidationInfo): def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SamplerName from imaginairy.samplers import SamplerName
if v is None: if v is None:
@ -462,7 +461,7 @@ class ImaginePrompt(BaseModel):
return v return v
@field_validator("steps") @field_validator("steps")
def validate_steps(cls, v, info: FieldValidationInfo): def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SAMPLER_LOOKUP from imaginairy.samplers import SAMPLER_LOOKUP
if v is None: if v is None:
@ -484,7 +483,7 @@ class ImaginePrompt(BaseModel):
return self return self
@field_validator("height", "width") @field_validator("height", "width")
def validate_image_size(cls, v, info: FieldValidationInfo): def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.model_manager import get_model_default_image_size from imaginairy.model_manager import get_model_default_image_size
if v is None: if v is None:

View File

@ -10,3 +10,5 @@ pytest-randomly
pytest-sugar pytest-sugar
responses responses
wheel wheel
-c tests/constraints.txt

View File

@ -16,7 +16,7 @@ anyio==3.7.1
# via # via
# fastapi # fastapi
# starlette # starlette
astroid==2.15.6 astroid==2.15.8
# via pylint # via pylint
async-timeout==4.0.3 async-timeout==4.0.3
# via aiohttp # via aiohttp
@ -42,13 +42,13 @@ click-help-colors==0.9.2
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
click-shell==2.1 click-shell==2.1
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
contourpy==1.1.0 contourpy==1.1.1
# via matplotlib # via matplotlib
coverage==7.3.1 coverage==7.3.1
# via -r requirements-dev.in # via -r requirements-dev.in
cycler==0.11.0 cycler==0.12.0
# via matplotlib # via matplotlib
diffusers==0.20.2 diffusers==0.21.3
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
dill==0.3.7 dill==0.3.7
# via pylint # via pylint
@ -62,9 +62,9 @@ facexlib==0.3.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
fairscale==0.4.13 fairscale==0.4.13
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
fastapi==0.103.1 fastapi==0.103.2
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
filelock==3.12.3 filelock==3.12.4
# via # via
# diffusers # diffusers
# huggingface-hub # huggingface-hub
@ -77,7 +77,7 @@ frozenlist==1.4.0
# via # via
# aiohttp # aiohttp
# aiosignal # aiosignal
fsspec[http]==2023.9.0 fsspec[http]==2023.9.2
# via # via
# huggingface-hub # huggingface-hub
# pytorch-lightning # pytorch-lightning
@ -87,7 +87,7 @@ ftfy==6.1.1
# open-clip-torch # open-clip-torch
h11==0.14.0 h11==0.14.0
# via uvicorn # via uvicorn
huggingface-hub==0.17.1 huggingface-hub==0.17.3
# via # via
# diffusers # diffusers
# open-clip-torch # open-clip-torch
@ -98,7 +98,7 @@ idna==3.4
# anyio # anyio
# requests # requests
# yarl # yarl
imageio==2.31.3 imageio==2.31.4
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
importlib-metadata==6.8.0 importlib-metadata==6.8.0
# via diffusers # via diffusers
@ -120,10 +120,12 @@ lightning-utilities==0.9.0
# via # via
# pytorch-lightning # pytorch-lightning
# torchmetrics # torchmetrics
llvmlite==0.40.1 llvmlite==0.41.0
# via numba # via numba
matplotlib==3.7.3 matplotlib==3.7.3
# via filterpy # via
# -c tests/constraints.txt
# filterpy
mccabe==0.7.0 mccabe==0.7.0
# via # via
# pylama # pylama
@ -136,10 +138,11 @@ mypy-extensions==1.0.0
# via # via
# black # black
# typing-inspect # typing-inspect
numba==0.57.1 numba==0.58.0
# via facexlib # via facexlib
numpy==1.24.4 numpy==1.24.4
# via # via
# -c tests/constraints.txt
# contourpy # contourpy
# diffusers # diffusers
# facexlib # facexlib
@ -159,7 +162,7 @@ omegaconf==2.3.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
open-clip-torch==2.20.0 open-clip-torch==2.20.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
opencv-python==4.8.0.76 opencv-python==4.8.1.78
# via # via
# facexlib # facexlib
# imaginAIry (setup.py) # imaginAIry (setup.py)
@ -178,7 +181,7 @@ pathspec==0.11.2
# via # via
# black # black
# pycln # pycln
pillow==10.0.0 pillow==10.0.1
# via # via
# diffusers # diffusers
# facexlib # facexlib
@ -202,11 +205,11 @@ pycln==2.2.2
# via -r requirements-dev.in # via -r requirements-dev.in
pycodestyle==2.11.0 pycodestyle==2.11.0
# via pylama # via pylama
pydantic==2.3.0 pydantic==2.4.2
# via # via
# fastapi # fastapi
# imaginAIry (setup.py) # imaginAIry (setup.py)
pydantic-core==2.6.3 pydantic-core==2.10.1
# via pydantic # via pydantic
pydocstyle==6.3.0 pydocstyle==6.3.0
# via pylama # via pylama
@ -214,7 +217,7 @@ pyflakes==3.1.0
# via pylama # via pylama
pylama==8.4.1 pylama==8.4.1
# via -r requirements-dev.in # via -r requirements-dev.in
pylint==2.17.5 pylint==2.17.6
# via -r requirements-dev.in # via -r requirements-dev.in
pyparsing==3.1.1 pyparsing==3.1.1
# via matplotlib # via matplotlib
@ -257,7 +260,7 @@ requests==2.31.0
# transformers # transformers
responses==0.23.3 responses==0.23.3
# via -r requirements-dev.in # via -r requirements-dev.in
ruff==0.0.288 ruff==0.0.291
# via -r requirements-dev.in # via -r requirements-dev.in
safetensors==0.3.3 safetensors==0.3.3
# via # via
@ -312,7 +315,7 @@ torch==1.13.1
# torchvision # torchvision
torchdiffeq==0.2.3 torchdiffeq==0.2.3
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
torchmetrics==1.1.2 torchmetrics==1.2.0
# via # via
# imaginAIry (setup.py) # imaginAIry (setup.py)
# pytorch-lightning # pytorch-lightning
@ -330,18 +333,17 @@ tqdm==4.66.1
# open-clip-torch # open-clip-torch
# pytorch-lightning # pytorch-lightning
# transformers # transformers
transformers==4.33.1 transformers==4.33.3
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
typer==0.9.0 typer==0.9.0
# via pycln # via pycln
types-pyyaml==6.0.12.11 types-pyyaml==6.0.12.12
# via responses # via responses
typing-extensions==4.7.1 typing-extensions==4.8.0
# via # via
# astroid # astroid
# black # black
# fastapi # fastapi
# filelock
# huggingface-hub # huggingface-hub
# libcst # libcst
# lightning-utilities # lightning-utilities
@ -355,13 +357,13 @@ typing-extensions==4.7.1
# uvicorn # uvicorn
typing-inspect==0.9.0 typing-inspect==0.9.0
# via libcst # via libcst
urllib3==2.0.4 urllib3==2.0.5
# via # via
# requests # requests
# responses # responses
uvicorn==0.23.2 uvicorn==0.23.2
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
wcwidth==0.2.6 wcwidth==0.2.7
# via ftfy # via ftfy
wheel==0.41.2 wheel==0.41.2
# via -r requirements-dev.in # via -r requirements-dev.in
@ -369,5 +371,5 @@ wrapt==1.15.0
# via astroid # via astroid
yarl==1.9.2 yarl==1.9.2
# via aiohttp # via aiohttp
zipp==3.16.2 zipp==3.17.0
# via importlib-metadata # via importlib-metadata

View File

@ -103,9 +103,10 @@ setup(
"scipy<1.11", "scipy<1.11",
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip "timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
"torchdiffeq>=0.2.0", "torchdiffeq>=0.2.0",
"transformers>=4.19.2",
"torchmetrics>=0.6.0", "torchmetrics>=0.6.0",
"torchvision>=0.13.1", "torchvision>=0.13.1",
"transformers>=4.19.2",
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",
"kornia>=0.6", "kornia>=0.6",
"uvicorn>=0.16.0", "uvicorn>=0.16.0",
"xformers>=0.0.16; sys_platform!='darwin' and platform_machine!='aarch64'", "xformers>=0.0.16; sys_platform!='darwin' and platform_machine!='aarch64'",

3
tests/constraints.txt Normal file
View File

@ -0,0 +1,3 @@
# held back for python 3.8 compatability
matplotlib<3.8.0
numpy<1.25.0

View File

@ -101,3 +101,14 @@ def test_image_deserialization(red_path, red_url):
for row in rows: for row in rows:
obj = TestModel.model_validate(row) obj = TestModel.model_validate(row)
assert obj.header_img.size == (512, 512) assert obj.header_img.size == (512, 512)
def test_image_state(red_path):
"""I dont remember what this fixes. Maybe the ability of pydantic to copy an object?."""
img = LazyLoadingImage(filepath=red_path)
# bypass init
img2 = LazyLoadingImage.__new__(LazyLoadingImage)
img2.__setstate__(img.__getstate__())
assert repr(img) == repr(img2)