feature: API support for StableStudio
parent
e53459a50a
commit
8e28a2ed02
@ -0,0 +1,114 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from imaginairy.http.utils import Base64Bytes
|
||||
|
||||
|
||||
class ImagineWebPrompt(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
prompt: Optional[str]
|
||||
negative_prompt: Optional[str] = None
|
||||
prompt_strength: Optional[float] = None
|
||||
init_image: Optional[Base64Bytes] = None
|
||||
init_image_strength: Optional[float] = None
|
||||
# control_inputs: Optional[List[ControlInput]] = None
|
||||
mask_prompt: Optional[str] = None
|
||||
mask_image: Optional[Base64Bytes] = None
|
||||
mask_mode: str = "replace"
|
||||
mask_modify_original: bool = True
|
||||
outpaint: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
steps: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
width: Optional[int] = None
|
||||
upscale: bool = False
|
||||
fix_faces: bool = False
|
||||
fix_faces_fidelity: float = 0.2
|
||||
sampler_type: Optional[str] = None
|
||||
conditioning: Optional[str] = None
|
||||
tile_mode: str = ""
|
||||
allow_compose_phase: bool = True
|
||||
model: Optional[str] = None
|
||||
model_config_path: Optional[str] = None
|
||||
is_intermediate: bool = False
|
||||
collect_progress_latents: bool = False
|
||||
caption_text: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_stable_studio_input(cls, stable_input):
|
||||
positive_prompt = stable_input.prompts[0].text
|
||||
negative_prompt = (
|
||||
stable_input.prompts[1].text if len(stable_input.prompts) > 1 else None
|
||||
)
|
||||
|
||||
init_image = None
|
||||
init_image_strength = None
|
||||
if stable_input.initial_image:
|
||||
init_image = stable_input.initial_image.blob
|
||||
init_image_strength = stable_input.initial_image.weight
|
||||
|
||||
mask_image = stable_input.mask_image.blob if stable_input.mask_image else None
|
||||
|
||||
sampler_type = stable_input.sampler.id if stable_input.sampler else None
|
||||
|
||||
return cls(
|
||||
prompt=positive_prompt,
|
||||
prompt_strength=stable_input.cfg_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
model=stable_input.model,
|
||||
sampler_type=sampler_type,
|
||||
seed=stable_input.seed,
|
||||
steps=stable_input.steps,
|
||||
height=stable_input.height,
|
||||
width=stable_input.width,
|
||||
init_image=init_image,
|
||||
init_image_strength=init_image_strength,
|
||||
mask_image=mask_image,
|
||||
mask_mode="keep",
|
||||
)
|
||||
|
||||
def to_imagine_prompt(self):
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from imaginairy import ImaginePrompt
|
||||
|
||||
imagine_prompt = ImaginePrompt(
|
||||
prompt=self.prompt,
|
||||
negative_prompt=self.negative_prompt,
|
||||
prompt_strength=self.prompt_strength,
|
||||
init_image=Image.open(BytesIO(self.init_image))
|
||||
if self.init_image
|
||||
else None,
|
||||
init_image_strength=self.init_image_strength,
|
||||
# control_inputs=self.control_inputs, # Uncomment this if the control_inputs field exists in ImagineWebPrompt
|
||||
mask_prompt=self.mask_prompt,
|
||||
mask_image=Image.open(BytesIO(self.mask_image))
|
||||
if self.mask_image
|
||||
else None,
|
||||
mask_mode=self.mask_mode,
|
||||
mask_modify_original=self.mask_modify_original,
|
||||
outpaint=self.outpaint,
|
||||
seed=self.seed,
|
||||
steps=self.steps,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
upscale=self.upscale,
|
||||
fix_faces=self.fix_faces,
|
||||
fix_faces_fidelity=self.fix_faces_fidelity,
|
||||
sampler_type=self.sampler_type,
|
||||
conditioning=self.conditioning,
|
||||
tile_mode=self.tile_mode,
|
||||
allow_compose_phase=self.allow_compose_phase,
|
||||
model=self.model,
|
||||
model_config_path=self.model_config_path,
|
||||
is_intermediate=self.is_intermediate,
|
||||
collect_progress_latents=self.collect_progress_latents,
|
||||
caption_text=self.caption_text,
|
||||
)
|
||||
|
||||
return imagine_prompt
|
@ -0,0 +1,81 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, HttpUrl, validator
|
||||
|
||||
from imaginairy.http.utils import Base64Bytes
|
||||
|
||||
|
||||
class StableStudioPrompt(BaseModel):
|
||||
text: Optional[str] = None
|
||||
weight: Optional[float] = Field(None, ge=-1, le=1)
|
||||
|
||||
|
||||
class StableStudioModel(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
image: Optional[HttpUrl] = None
|
||||
|
||||
|
||||
class StableStudioStyle(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
image: Optional[HttpUrl] = None
|
||||
|
||||
|
||||
class StableStudioSampler(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class StableStudioInputImage(BaseModel):
|
||||
blob: Optional[Base64Bytes] = None
|
||||
weight: Optional[float] = Field(None, ge=0, le=1)
|
||||
|
||||
|
||||
class StableStudioImage(BaseModel):
|
||||
id: str
|
||||
created_at: Optional[datetime] = None
|
||||
input: Optional["StableStudioInput"] = None
|
||||
blob: Optional[Base64Bytes] = None
|
||||
|
||||
|
||||
class StableStudioImages(BaseModel):
|
||||
id: str
|
||||
exclusive_start_image_id: Optional[str] = None
|
||||
images: Optional[List[StableStudioImage]] = None
|
||||
|
||||
|
||||
class StableStudioInput(BaseModel, extra=Extra.forbid):
|
||||
prompts: Optional[List[StableStudioPrompt]] = None
|
||||
model: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
sampler: Optional[StableStudioSampler] = None
|
||||
cfg_scale: Optional[float] = Field(None, alias="cfgScale")
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
mask_image: Optional[StableStudioInputImage] = Field(None, alias="maskImage")
|
||||
initial_image: Optional[StableStudioInputImage] = Field(None, alias="initialImage")
|
||||
|
||||
@validator("seed")
|
||||
def validate_seed(cls, v): # noqa
|
||||
if v == 0:
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
class StableStudioBatchRequest(BaseModel):
|
||||
input: StableStudioInput
|
||||
count: int = 1
|
||||
|
||||
|
||||
class StableStudioBatchResponse(BaseModel):
|
||||
images: List[StableStudioImage]
|
||||
|
||||
|
||||
StableStudioInput.update_forward_refs()
|
||||
StableStudioImage.update_forward_refs()
|
@ -0,0 +1,64 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from imaginairy.http.models import ImagineWebPrompt
|
||||
from imaginairy.http.stablestudio.models import (
|
||||
StableStudioBatchRequest,
|
||||
StableStudioBatchResponse,
|
||||
StableStudioImage,
|
||||
StableStudioModel,
|
||||
StableStudioSampler,
|
||||
)
|
||||
from imaginairy.http.utils import generate_image_b64
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/generate", response_model=StableStudioBatchResponse)
|
||||
async def generate(studio_request: StableStudioBatchRequest):
|
||||
from imaginairy.http.app import gpu_lock
|
||||
|
||||
generated_images = []
|
||||
imagine_prompt = ImagineWebPrompt.from_stable_studio_input(studio_request.input)
|
||||
starting_seed = imagine_prompt.seed if imagine_prompt.seed is not None else None
|
||||
|
||||
for run_num in range(studio_request.count):
|
||||
if starting_seed is not None:
|
||||
imagine_prompt.seed = starting_seed + run_num
|
||||
async with gpu_lock:
|
||||
img_base64 = await run_in_threadpool(generate_image_b64, imagine_prompt)
|
||||
|
||||
image = StableStudioImage(id=str(uuid.uuid4()), blob=img_base64)
|
||||
generated_images.append(image)
|
||||
|
||||
return StableStudioBatchResponse(images=generated_images)
|
||||
|
||||
|
||||
@router.get("/samplers")
|
||||
async def list_samplers():
|
||||
from imaginairy.config import SAMPLER_TYPE_OPTIONS
|
||||
|
||||
sampler_objs = []
|
||||
for sampler_type in SAMPLER_TYPE_OPTIONS:
|
||||
sampler_obj = StableStudioSampler(id=sampler_type, name=sampler_type)
|
||||
sampler_objs.append(sampler_obj)
|
||||
|
||||
return sampler_objs
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models():
|
||||
from imaginairy.config import MODEL_CONFIGS
|
||||
|
||||
model_objs = []
|
||||
for model_config in MODEL_CONFIGS:
|
||||
model_obj = StableStudioModel(
|
||||
id=model_config.short_name,
|
||||
name=model_config.description,
|
||||
description=model_config.description,
|
||||
)
|
||||
model_objs.append(model_obj)
|
||||
|
||||
return model_objs
|
@ -0,0 +1,39 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from imaginairy import imagine
|
||||
|
||||
|
||||
def generate_image(prompt):
|
||||
"""ImagineWebPrompt to generated image"""
|
||||
prompt = prompt.to_imagine_prompt()
|
||||
result = next(imagine([prompt]))
|
||||
img = result.images["generated"]
|
||||
img_io = BytesIO()
|
||||
img.save(img_io, "JPEG")
|
||||
img_io.seek(0)
|
||||
return img_io
|
||||
|
||||
|
||||
def generate_image_b64(prompt):
|
||||
"""ImagineWebPrompt to generated base64 encoded image"""
|
||||
img_io = generate_image(prompt)
|
||||
img_base64 = base64.b64encode(img_io.getvalue())
|
||||
return img_base64
|
||||
|
||||
|
||||
class Base64Bytes(bytes):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
return base64.b64decode(v)
|
||||
raise ValueError("Byte value must be either str or bytes")
|
||||
|
||||
def __str__(self):
|
||||
return base64.b64encode(self).decode()
|
Loading…
Reference in New Issue