mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
fix: pydantic models for http server working now. Fixes #380
This commit is contained in:
parent
ba51364a73
commit
8243ed616d
1
.gitignore
vendored
1
.gitignore
vendored
@ -30,3 +30,4 @@ tests/vastai_cli.py
|
||||
**/.eggs
|
||||
/img_size_memory_usage.csv
|
||||
/tests/test_cluster_output/
|
||||
/.env
|
||||
|
@ -493,6 +493,10 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
|
||||
|
||||
## ChangeLog
|
||||
|
||||
**13.2.1**
|
||||
- fix: pydantic models for http server working now. Fixes #380
|
||||
- fix: install triton so annoying message is gone
|
||||
|
||||
**13.2.0**
|
||||
- fix: allow tile_mode to be set to True or False for backward compatibility
|
||||
- fix: various pydantic issues have been resolved
|
||||
|
@ -119,5 +119,5 @@ class StableStudioBatchResponse(BaseModel):
|
||||
images: List[StableStudioImage]
|
||||
|
||||
|
||||
StableStudioInput.update_forward_refs()
|
||||
StableStudioImage.update_forward_refs()
|
||||
StableStudioInput.model_rebuild()
|
||||
StableStudioImage.model_rebuild()
|
||||
|
@ -5,7 +5,7 @@ from imaginairy import imagine
|
||||
|
||||
|
||||
def generate_image(prompt):
|
||||
"""ImaginPrompt to generated image."""
|
||||
"""ImaginePrompt to generated image."""
|
||||
result = next(imagine([prompt]))
|
||||
img = result.images["generated"]
|
||||
img_io = BytesIO()
|
||||
@ -27,7 +27,7 @@ class Base64Bytes(bytes):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
def validate(cls, v, info):
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
|
@ -18,7 +18,6 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
from pydantic_core.core_schema import FieldValidationInfo
|
||||
|
||||
from imaginairy import config
|
||||
|
||||
@ -90,6 +89,12 @@ class LazyLoadingImage:
|
||||
self._load_img()
|
||||
return getattr(self._img, key)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
def _load_img(self):
|
||||
if self._img is None:
|
||||
from PIL import Image, ImageOps
|
||||
@ -192,16 +197,10 @@ class LazyLoadingImage:
|
||||
|
||||
shows filepath or url if available.
|
||||
"""
|
||||
try:
|
||||
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"
|
||||
|
||||
|
||||
#
|
||||
# LazyLoadingImage = Annotated[
|
||||
# _LazyLoadingImage,
|
||||
# AfterValidator(_LazyLoadingImage.validate),
|
||||
# PlainSerializer(lambda i: str(i), return_type=str),
|
||||
# WithJsonSchema({"type": "string"}, mode="serialization"),
|
||||
# ]
|
||||
except Exception as e: # noqa
|
||||
return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>"
|
||||
|
||||
|
||||
class ControlNetInput(BaseModel):
|
||||
@ -218,7 +217,7 @@ class ControlNetInput(BaseModel):
|
||||
# return v
|
||||
|
||||
@field_validator("image_raw")
|
||||
def image_raw_validate(cls, v, info: FieldValidationInfo):
|
||||
def image_raw_validate(cls, v, info: core_schema.FieldValidationInfo):
|
||||
if info.data.get("image") is not None and v is not None:
|
||||
raise ValueError("You cannot specify both image and image_raw")
|
||||
|
||||
@ -245,7 +244,7 @@ class WeightedPrompt(BaseModel):
|
||||
return f"{self.weight}*({self.text})"
|
||||
|
||||
|
||||
class ImaginePrompt(BaseModel):
|
||||
class ImaginePrompt(BaseModel, protected_namespaces=()):
|
||||
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
|
||||
negative_prompt: Optional[List[WeightedPrompt]] = Field(
|
||||
default=None, validate_default=True
|
||||
@ -403,7 +402,7 @@ class ImaginePrompt(BaseModel):
|
||||
return v
|
||||
|
||||
@field_validator("control_inputs", mode="after")
|
||||
def set_image_from_init_image(cls, v, info: FieldValidationInfo):
|
||||
def set_image_from_init_image(cls, v, info: core_schema.FieldValidationInfo):
|
||||
v = v or []
|
||||
for control_input in v:
|
||||
if control_input.image is None and control_input.image_raw is None:
|
||||
@ -411,13 +410,13 @@ class ImaginePrompt(BaseModel):
|
||||
return v
|
||||
|
||||
@field_validator("mask_image")
|
||||
def validate_mask_image(cls, v, info: FieldValidationInfo):
|
||||
def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
|
||||
if v is not None and info.data.get("mask_prompt") is not None:
|
||||
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
|
||||
return v
|
||||
|
||||
@field_validator("mask_prompt", "mask_image", mode="before")
|
||||
def validate_mask_prompt(cls, v, info: FieldValidationInfo):
|
||||
def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
|
||||
if info.data.get("init_image") is None and v:
|
||||
raise ValueError("You must set `init_image` if you want to use a mask")
|
||||
return v
|
||||
@ -441,7 +440,7 @@ class ImaginePrompt(BaseModel):
|
||||
return v
|
||||
|
||||
@field_validator("sampler_type", mode="after")
|
||||
def validate_sampler_type(cls, v, info: FieldValidationInfo):
|
||||
def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo):
|
||||
from imaginairy.samplers import SamplerName
|
||||
|
||||
if v is None:
|
||||
@ -462,7 +461,7 @@ class ImaginePrompt(BaseModel):
|
||||
return v
|
||||
|
||||
@field_validator("steps")
|
||||
def validate_steps(cls, v, info: FieldValidationInfo):
|
||||
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
|
||||
from imaginairy.samplers import SAMPLER_LOOKUP
|
||||
|
||||
if v is None:
|
||||
@ -484,7 +483,7 @@ class ImaginePrompt(BaseModel):
|
||||
return self
|
||||
|
||||
@field_validator("height", "width")
|
||||
def validate_image_size(cls, v, info: FieldValidationInfo):
|
||||
def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
|
||||
from imaginairy.model_manager import get_model_default_image_size
|
||||
|
||||
if v is None:
|
||||
|
@ -10,3 +10,5 @@ pytest-randomly
|
||||
pytest-sugar
|
||||
responses
|
||||
wheel
|
||||
|
||||
-c tests/constraints.txt
|
@ -16,7 +16,7 @@ anyio==3.7.1
|
||||
# via
|
||||
# fastapi
|
||||
# starlette
|
||||
astroid==2.15.6
|
||||
astroid==2.15.8
|
||||
# via pylint
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
@ -42,13 +42,13 @@ click-help-colors==0.9.2
|
||||
# via imaginAIry (setup.py)
|
||||
click-shell==2.1
|
||||
# via imaginAIry (setup.py)
|
||||
contourpy==1.1.0
|
||||
contourpy==1.1.1
|
||||
# via matplotlib
|
||||
coverage==7.3.1
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.11.0
|
||||
cycler==0.12.0
|
||||
# via matplotlib
|
||||
diffusers==0.20.2
|
||||
diffusers==0.21.3
|
||||
# via imaginAIry (setup.py)
|
||||
dill==0.3.7
|
||||
# via pylint
|
||||
@ -62,9 +62,9 @@ facexlib==0.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
fairscale==0.4.13
|
||||
# via imaginAIry (setup.py)
|
||||
fastapi==0.103.1
|
||||
fastapi==0.103.2
|
||||
# via imaginAIry (setup.py)
|
||||
filelock==3.12.3
|
||||
filelock==3.12.4
|
||||
# via
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
@ -77,7 +77,7 @@ frozenlist==1.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2023.9.0
|
||||
fsspec[http]==2023.9.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pytorch-lightning
|
||||
@ -87,7 +87,7 @@ ftfy==6.1.1
|
||||
# open-clip-torch
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.17.1
|
||||
huggingface-hub==0.17.3
|
||||
# via
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
@ -98,7 +98,7 @@ idna==3.4
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.31.3
|
||||
imageio==2.31.4
|
||||
# via imaginAIry (setup.py)
|
||||
importlib-metadata==6.8.0
|
||||
# via diffusers
|
||||
@ -120,10 +120,12 @@ lightning-utilities==0.9.0
|
||||
# via
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
llvmlite==0.40.1
|
||||
llvmlite==0.41.0
|
||||
# via numba
|
||||
matplotlib==3.7.3
|
||||
# via filterpy
|
||||
# via
|
||||
# -c tests/constraints.txt
|
||||
# filterpy
|
||||
mccabe==0.7.0
|
||||
# via
|
||||
# pylama
|
||||
@ -136,10 +138,11 @@ mypy-extensions==1.0.0
|
||||
# via
|
||||
# black
|
||||
# typing-inspect
|
||||
numba==0.57.1
|
||||
numba==0.58.0
|
||||
# via facexlib
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -c tests/constraints.txt
|
||||
# contourpy
|
||||
# diffusers
|
||||
# facexlib
|
||||
@ -159,7 +162,7 @@ omegaconf==2.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
open-clip-torch==2.20.0
|
||||
# via imaginAIry (setup.py)
|
||||
opencv-python==4.8.0.76
|
||||
opencv-python==4.8.1.78
|
||||
# via
|
||||
# facexlib
|
||||
# imaginAIry (setup.py)
|
||||
@ -178,7 +181,7 @@ pathspec==0.11.2
|
||||
# via
|
||||
# black
|
||||
# pycln
|
||||
pillow==10.0.0
|
||||
pillow==10.0.1
|
||||
# via
|
||||
# diffusers
|
||||
# facexlib
|
||||
@ -202,11 +205,11 @@ pycln==2.2.2
|
||||
# via -r requirements-dev.in
|
||||
pycodestyle==2.11.0
|
||||
# via pylama
|
||||
pydantic==2.3.0
|
||||
pydantic==2.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# imaginAIry (setup.py)
|
||||
pydantic-core==2.6.3
|
||||
pydantic-core==2.10.1
|
||||
# via pydantic
|
||||
pydocstyle==6.3.0
|
||||
# via pylama
|
||||
@ -214,7 +217,7 @@ pyflakes==3.1.0
|
||||
# via pylama
|
||||
pylama==8.4.1
|
||||
# via -r requirements-dev.in
|
||||
pylint==2.17.5
|
||||
pylint==2.17.6
|
||||
# via -r requirements-dev.in
|
||||
pyparsing==3.1.1
|
||||
# via matplotlib
|
||||
@ -257,7 +260,7 @@ requests==2.31.0
|
||||
# transformers
|
||||
responses==0.23.3
|
||||
# via -r requirements-dev.in
|
||||
ruff==0.0.288
|
||||
ruff==0.0.291
|
||||
# via -r requirements-dev.in
|
||||
safetensors==0.3.3
|
||||
# via
|
||||
@ -312,7 +315,7 @@ torch==1.13.1
|
||||
# torchvision
|
||||
torchdiffeq==0.2.3
|
||||
# via imaginAIry (setup.py)
|
||||
torchmetrics==1.1.2
|
||||
torchmetrics==1.2.0
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# pytorch-lightning
|
||||
@ -330,18 +333,17 @@ tqdm==4.66.1
|
||||
# open-clip-torch
|
||||
# pytorch-lightning
|
||||
# transformers
|
||||
transformers==4.33.1
|
||||
transformers==4.33.3
|
||||
# via imaginAIry (setup.py)
|
||||
typer==0.9.0
|
||||
# via pycln
|
||||
types-pyyaml==6.0.12.11
|
||||
types-pyyaml==6.0.12.12
|
||||
# via responses
|
||||
typing-extensions==4.7.1
|
||||
typing-extensions==4.8.0
|
||||
# via
|
||||
# astroid
|
||||
# black
|
||||
# fastapi
|
||||
# filelock
|
||||
# huggingface-hub
|
||||
# libcst
|
||||
# lightning-utilities
|
||||
@ -355,13 +357,13 @@ typing-extensions==4.7.1
|
||||
# uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via libcst
|
||||
urllib3==2.0.4
|
||||
urllib3==2.0.5
|
||||
# via
|
||||
# requests
|
||||
# responses
|
||||
uvicorn==0.23.2
|
||||
# via imaginAIry (setup.py)
|
||||
wcwidth==0.2.6
|
||||
wcwidth==0.2.7
|
||||
# via ftfy
|
||||
wheel==0.41.2
|
||||
# via -r requirements-dev.in
|
||||
@ -369,5 +371,5 @@ wrapt==1.15.0
|
||||
# via astroid
|
||||
yarl==1.9.2
|
||||
# via aiohttp
|
||||
zipp==3.16.2
|
||||
zipp==3.17.0
|
||||
# via importlib-metadata
|
||||
|
3
setup.py
3
setup.py
@ -103,9 +103,10 @@ setup(
|
||||
"scipy<1.11",
|
||||
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
|
||||
"torchdiffeq>=0.2.0",
|
||||
"transformers>=4.19.2",
|
||||
"torchmetrics>=0.6.0",
|
||||
"torchvision>=0.13.1",
|
||||
"transformers>=4.19.2",
|
||||
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",
|
||||
"kornia>=0.6",
|
||||
"uvicorn>=0.16.0",
|
||||
"xformers>=0.0.16; sys_platform!='darwin' and platform_machine!='aarch64'",
|
||||
|
3
tests/constraints.txt
Normal file
3
tests/constraints.txt
Normal file
@ -0,0 +1,3 @@
|
||||
# held back for python 3.8 compatability
|
||||
matplotlib<3.8.0
|
||||
numpy<1.25.0
|
@ -101,3 +101,14 @@ def test_image_deserialization(red_path, red_url):
|
||||
for row in rows:
|
||||
obj = TestModel.model_validate(row)
|
||||
assert obj.header_img.size == (512, 512)
|
||||
|
||||
|
||||
def test_image_state(red_path):
|
||||
"""I dont remember what this fixes. Maybe the ability of pydantic to copy an object?."""
|
||||
img = LazyLoadingImage(filepath=red_path)
|
||||
|
||||
# bypass init
|
||||
img2 = LazyLoadingImage.__new__(LazyLoadingImage)
|
||||
img2.__setstate__(img.__getstate__())
|
||||
|
||||
assert repr(img) == repr(img2)
|
||||
|
Loading…
Reference in New Issue
Block a user