feature: API support for StableStudio

pull/339/head
Bryce 1 year ago committed by Bryce Drennan
parent e53459a50a
commit 8e28a2ed02

@ -486,8 +486,9 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
**13.0.0**
- 🎉 feature: multi-controlnet support. pass in multiple `--control-mode`, `--control-image`, and `--control-image-raw` arguments.
- 🎉 feature: add colorization controlnet. improve `aimg colorize` command
- 🧪 feature: "better" memory management. If GPU is full, least-recently-used model is moved to RAM. I'm not confident this works well.
- 🧪 alpha feature: `aimg run-api-server` command. Runs a http webserver (not finished). After running, visit http://127.0.0.1:8000/docs for api.
- 🎉🧪 feature: API server `aimg server` command. Runs an HTTP web server (not finished). After running, visit http://127.0.0.1:8000/docs for the API documentation.
- 🎉🧪 feature: API support for [Stability AI's new open-source Generative AI interface, StableStudio](https://github.com/Stability-AI/StableStudio).
- 🎉🧪 feature: "better" memory management. If GPU is full, least-recently-used model is moved to RAM. I'm not confident this works well.
- feature: [disabled] inpainting controlnet can be used instead of finetuned inpainting model
- The inpainting controlnet doesn't work as well as the finetuned model
- feature: python interface allows configuration of controlnet strength

@ -8,7 +8,7 @@ from imaginairy.cli.describe import describe_cmd
from imaginairy.cli.edit import edit_cmd
from imaginairy.cli.edit_demo import edit_demo_cmd
from imaginairy.cli.imagine import imagine_cmd
from imaginairy.cli.run_api import run_api_server_cmd
from imaginairy.cli.run_api import run_server_cmd
from imaginairy.cli.train import prep_images_cmd, prune_ckpt_cmd, train_concept_cmd
from imaginairy.cli.upscale import upscale_cmd
@ -49,7 +49,7 @@ aimg.add_command(prep_images_cmd, name="prep-images")
aimg.add_command(prune_ckpt_cmd, name="prune-ckpt")
aimg.add_command(train_concept_cmd, name="train-concept")
aimg.add_command(upscale_cmd, name="upscale")
aimg.add_command(run_api_server_cmd, name="run-api-server")
aimg.add_command(run_server_cmd, name="server")
@aimg.command()

@ -1,8 +1,8 @@
import click
@click.command("run-api-server")
def run_api_server_cmd():
@click.command("run-server")
def run_server_cmd():
"""Run a HTTP API server."""
import uvicorn

@ -1,81 +1,42 @@
from asyncio import Lock
from io import BytesIO
from typing import Optional
from fastapi import FastAPI, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel # noqa
from imaginairy import ImaginePrompt, imagine
from imaginairy.log_utils import configure_logging
from imaginairy.http.models import ImagineWebPrompt
from imaginairy.http.stablestudio import routes
from imaginairy.http.utils import generate_image
app = FastAPI()
lock = Lock()
class ImagineWebPrompt(BaseModel):
class Config:
arbitrary_types_allowed = True
gpu_lock = Lock()
prompt: Optional[str]
negative_prompt: Optional[str]
prompt_strength: float = 7.5
# init_image: Optional[Union[LazyLoadingImage, str]]
init_image_strength: Optional[float] = None
# control_inputs: Optional[List[ControlInput]] = None
mask_prompt: Optional[str] = None
# mask_image: Optional[Union[LazyLoadingImage, str]] = None
mask_mode: str = "replace"
mask_modify_original: bool = True
outpaint: Optional[str] = None
seed: Optional[int] = None
steps: Optional[int] = None
height: Optional[int] = None
width: Optional[int] = None
upscale: bool = False
fix_faces: bool = False
fix_faces_fidelity: float = 0.2
# sampler_type: str = Field(..., alias='config.DEFAULT_SAMPLER') # update the alias based on actual config field name
conditioning: Optional[str] = None
tile_mode: str = ""
allow_compose_phase: bool = True
# model: str = Field(..., alias='config.DEFAULT_MODEL') # update the alias based on actual config field name
model_config_path: Optional[str] = None
is_intermediate: bool = False
collect_progress_latents: bool = False
caption_text: str = ""
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000",
"http://localhost:3001",
"http://localhost:3002",
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def generate_image(prompt: ImagineWebPrompt):
prompt = ImaginePrompt(prompt.prompt)
result = next(imagine([prompt]))
return result.images["generated"]
app.include_router(routes.router, prefix="/api/stablestudio")
@app.post("/api/imagine")
async def imagine_endpoint(prompt: ImagineWebPrompt):
async with lock:
img = await run_in_threadpool(generate_image, prompt)
img_io = BytesIO()
img.save(img_io, "JPEG")
img_io.seek(0)
async with gpu_lock:
img_io = await run_in_threadpool(generate_image, prompt)
return StreamingResponse(img_io, media_type="image/jpg")
@app.get("/api/imagine")
async def imagine_get_endpoint(text: str = Query(...)):
async with lock:
img = await run_in_threadpool(generate_image, ImagineWebPrompt(prompt=text))
img_io = BytesIO()
img.save(img_io, "JPEG")
img_io.seek(0)
async with gpu_lock:
img_io = await run_in_threadpool(generate_image, ImagineWebPrompt(prompt=text))
return StreamingResponse(img_io, media_type="image/jpg")
if __name__ == "__main__":
import uvicorn
configure_logging()
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

@ -0,0 +1,114 @@
from typing import Optional
from pydantic import BaseModel
from imaginairy.http.utils import Base64Bytes
class ImagineWebPrompt(BaseModel):
    """Request model describing one image generation for the HTTP API.

    The fields correspond one-to-one to the keyword arguments handed to
    ``ImaginePrompt`` in :meth:`to_imagine_prompt`. Images are carried as
    ``Base64Bytes`` and converted to PIL images only when the
    ``ImaginePrompt`` is built.
    """

    class Config:
        arbitrary_types_allowed = True

    prompt: Optional[str]
    negative_prompt: Optional[str] = None
    prompt_strength: Optional[float] = None
    init_image: Optional[Base64Bytes] = None
    init_image_strength: Optional[float] = None
    # control_inputs: Optional[List[ControlInput]] = None
    mask_prompt: Optional[str] = None
    mask_image: Optional[Base64Bytes] = None
    mask_mode: str = "replace"
    mask_modify_original: bool = True
    outpaint: Optional[str] = None
    seed: Optional[int] = None
    steps: Optional[int] = None
    height: Optional[int] = None
    width: Optional[int] = None
    upscale: bool = False
    fix_faces: bool = False
    fix_faces_fidelity: float = 0.2
    sampler_type: Optional[str] = None
    conditioning: Optional[str] = None
    tile_mode: str = ""
    allow_compose_phase: bool = True
    model: Optional[str] = None
    model_config_path: Optional[str] = None
    is_intermediate: bool = False
    collect_progress_latents: bool = False
    caption_text: str = ""

    @classmethod
    def from_stable_studio_input(cls, stable_input):
        """Build an ``ImagineWebPrompt`` from a StableStudio input object.

        The first entry of ``stable_input.prompts`` becomes the positive
        prompt and the second (if present) the negative prompt. A missing
        or empty prompt list yields ``None`` for both instead of raising.

        Args:
            stable_input: object exposing ``prompts``, ``initial_image``,
                ``mask_image``, ``sampler``, ``cfg_scale``, ``model``,
                ``seed``, ``steps``, ``height`` and ``width`` attributes.

        Returns:
            A new instance of ``cls``; ``mask_mode`` is always ``"keep"``.
        """
        # `prompts` is optional on the StableStudio side, so normalise
        # None to an empty list before indexing into it.
        prompts = stable_input.prompts or []
        positive_prompt = prompts[0].text if prompts else None
        negative_prompt = prompts[1].text if len(prompts) > 1 else None

        init_image = None
        init_image_strength = None
        if stable_input.initial_image:
            init_image = stable_input.initial_image.blob
            init_image_strength = stable_input.initial_image.weight

        mask_image = stable_input.mask_image.blob if stable_input.mask_image else None
        sampler_type = stable_input.sampler.id if stable_input.sampler else None

        return cls(
            prompt=positive_prompt,
            prompt_strength=stable_input.cfg_scale,
            negative_prompt=negative_prompt,
            model=stable_input.model,
            sampler_type=sampler_type,
            seed=stable_input.seed,
            steps=stable_input.steps,
            height=stable_input.height,
            width=stable_input.width,
            init_image=init_image,
            init_image_strength=init_image_strength,
            mask_image=mask_image,
            mask_mode="keep",
        )

    def to_imagine_prompt(self):
        """Convert this web prompt into an ``ImaginePrompt``.

        ``init_image`` and ``mask_image`` are opened as PIL images when
        they hold data; empty or missing values are passed as ``None``.

        Returns:
            The ``ImaginePrompt`` built from this model's field values.
        """
        # Imported lazily, as in the original implementation.
        from io import BytesIO

        from PIL import Image

        from imaginairy import ImaginePrompt

        def _open_image(raw):
            # Falsy input (None or b"") means "no image supplied".
            return Image.open(BytesIO(raw)) if raw else None

        imagine_prompt = ImaginePrompt(
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            prompt_strength=self.prompt_strength,
            init_image=_open_image(self.init_image),
            init_image_strength=self.init_image_strength,
            # control_inputs=self.control_inputs,  # Uncomment this if the control_inputs field exists in ImagineWebPrompt
            mask_prompt=self.mask_prompt,
            mask_image=_open_image(self.mask_image),
            mask_mode=self.mask_mode,
            mask_modify_original=self.mask_modify_original,
            outpaint=self.outpaint,
            seed=self.seed,
            steps=self.steps,
            height=self.height,
            width=self.width,
            upscale=self.upscale,
            fix_faces=self.fix_faces,
            fix_faces_fidelity=self.fix_faces_fidelity,
            sampler_type=self.sampler_type,
            conditioning=self.conditioning,
            tile_mode=self.tile_mode,
            allow_compose_phase=self.allow_compose_phase,
            model=self.model,
            model_config_path=self.model_config_path,
            is_intermediate=self.is_intermediate,
            collect_progress_latents=self.collect_progress_latents,
            caption_text=self.caption_text,
        )
        return imagine_prompt

@ -0,0 +1,81 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Extra, Field, HttpUrl, validator
from imaginairy.http.utils import Base64Bytes
class StableStudioPrompt(BaseModel):
    """One text prompt with an optional weight limited to the range [-1, 1]."""

    text: Optional[str] = None
    weight: Optional[float] = Field(None, ge=-1, le=1)
class StableStudioModel(BaseModel):
    """Model entry: a required ``id`` plus optional display metadata."""

    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    image: Optional[HttpUrl] = None
class StableStudioStyle(BaseModel):
    """Style entry: a required ``id`` plus optional display metadata."""

    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    image: Optional[HttpUrl] = None
class StableStudioSampler(BaseModel):
    """Sampler entry identified by ``id`` with an optional display name."""

    id: str
    name: Optional[str] = None
class StableStudioInputImage(BaseModel):
    """Input image payload with an optional weight limited to [0, 1]."""

    blob: Optional[Base64Bytes] = None
    weight: Optional[float] = Field(None, ge=0, le=1)
class StableStudioImage(BaseModel):
    """An image record: required ``id`` and optional timestamp, input and data.

    ``input`` is annotated with the string "StableStudioInput" because that
    class is declared later in the module.
    """

    id: str
    created_at: Optional[datetime] = None
    input: Optional["StableStudioInput"] = None
    blob: Optional[Base64Bytes] = None
class StableStudioImages(BaseModel):
    """A collection of ``StableStudioImage`` objects under one ``id``."""

    id: str
    exclusive_start_image_id: Optional[str] = None
    images: Optional[List[StableStudioImage]] = None
class StableStudioInput(BaseModel, extra=Extra.forbid):
    """Generation parameters sent with a StableStudio request.

    Unknown fields are rejected (``Extra.forbid``). ``cfg_scale``,
    ``mask_image`` and ``initial_image`` are populated from the camelCase
    aliases ``cfgScale``, ``maskImage`` and ``initialImage``.
    """

    prompts: Optional[List[StableStudioPrompt]] = None
    model: Optional[str] = None
    style: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
    sampler: Optional[StableStudioSampler] = None
    cfg_scale: Optional[float] = Field(None, alias="cfgScale")
    steps: Optional[int] = None
    seed: Optional[int] = None
    mask_image: Optional[StableStudioInputImage] = Field(None, alias="maskImage")
    initial_image: Optional[StableStudioInputImage] = Field(None, alias="initialImage")

    @validator("seed")
    def validate_seed(cls, v):  # noqa
        """Map a seed of 0 to ``None``; return any other value unchanged."""
        # NOTE(review): presumably 0 means "no seed chosen" on the
        # StableStudio side - confirm against the client.
        return None if v == 0 else v
class StableStudioBatchRequest(BaseModel):
    """Request body: one ``StableStudioInput`` and an image count (default 1)."""

    input: StableStudioInput
    count: int = 1
class StableStudioBatchResponse(BaseModel):
    """Response body wrapping the list of generated ``StableStudioImage`` objects."""

    images: List[StableStudioImage]
# Resolve string (forward-reference) annotations now that every model class
# exists. StableStudioImage.input is annotated as "StableStudioInput".
# NOTE(review): no string annotation is visible on StableStudioInput itself,
# so the first call looks like a harmless no-op - confirm before removing.
StableStudioInput.update_forward_refs()
StableStudioImage.update_forward_refs()

@ -0,0 +1,64 @@
import uuid
from fastapi import APIRouter
from fastapi.concurrency import run_in_threadpool
from imaginairy.http.models import ImagineWebPrompt
from imaginairy.http.stablestudio.models import (
StableStudioBatchRequest,
StableStudioBatchResponse,
StableStudioImage,
StableStudioModel,
StableStudioSampler,
)
from imaginairy.http.utils import generate_image_b64
router = APIRouter()
@router.post("/generate", response_model=StableStudioBatchResponse)
async def generate(studio_request: StableStudioBatchRequest):
    """Generate ``studio_request.count`` images for a StableStudio request.

    The StableStudio input is converted to an ``ImagineWebPrompt`` once.
    When it carries a seed, image number ``i`` (starting at 0) uses
    ``seed + i``; without a seed the prompt is reused unchanged. Each
    generation runs in the threadpool while ``gpu_lock`` is held.

    Returns:
        A ``StableStudioBatchResponse`` with one ``StableStudioImage`` per
        generated image, each given a fresh UUID4 string as its id.
    """
    # Imported inside the function, as in the original code
    # (presumably to avoid a circular import with the app module).
    from imaginairy.http.app import gpu_lock

    web_prompt = ImagineWebPrompt.from_stable_studio_input(studio_request.input)
    base_seed = web_prompt.seed

    results = []
    for offset in range(studio_request.count):
        if base_seed is not None:
            web_prompt.seed = base_seed + offset

        async with gpu_lock:
            encoded = await run_in_threadpool(generate_image_b64, web_prompt)

        results.append(StableStudioImage(id=str(uuid.uuid4()), blob=encoded))

    return StableStudioBatchResponse(images=results)
@router.get("/samplers")
async def list_samplers():
    """Return one ``StableStudioSampler`` per entry in ``SAMPLER_TYPE_OPTIONS``.

    Each sampler uses the option value as both its ``id`` and its ``name``.
    """
    from imaginairy.config import SAMPLER_TYPE_OPTIONS

    return [
        StableStudioSampler(id=option, name=option)
        for option in SAMPLER_TYPE_OPTIONS
    ]
@router.get("/models")
async def list_models():
    """Return one ``StableStudioModel`` per entry in ``MODEL_CONFIGS``.

    The config's ``short_name`` is used as the id; its ``description`` is
    used for both ``name`` and ``description``.
    """
    from imaginairy.config import MODEL_CONFIGS

    return [
        StableStudioModel(
            id=cfg.short_name,
            name=cfg.description,
            description=cfg.description,
        )
        for cfg in MODEL_CONFIGS
    ]

@ -0,0 +1,39 @@
import base64
from io import BytesIO
from imaginairy import imagine
def generate_image(prompt):
    """Render an ImagineWebPrompt and return the image as an in-memory JPEG.

    Only the first result produced by ``imagine`` is used. The returned
    ``BytesIO`` is rewound to position 0 so it can be read immediately.
    """
    imagine_prompt = prompt.to_imagine_prompt()
    first_result = next(imagine([imagine_prompt]))

    buffer = BytesIO()
    first_result.images["generated"].save(buffer, "JPEG")
    buffer.seek(0)
    return buffer
def generate_image_b64(prompt):
    """Render an ImagineWebPrompt and return the JPEG as base64-encoded bytes."""
    jpeg_bytes = generate_image(prompt).getvalue()
    return base64.b64encode(jpeg_bytes)
class Base64Bytes(bytes):
    """``bytes`` subclass that validates from base64 text and prints as base64.

    ``__get_validators__`` exposes :meth:`validate` as this type's only
    validator.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Return ``v`` as bytes.

        ``str`` input is base64-decoded; ``bytes`` input is returned
        unchanged. Any other type raises ``ValueError``.
        """
        if isinstance(v, str):
            return base64.b64decode(v)
        if not isinstance(v, bytes):
            raise ValueError("Byte value must be either str or bytes")
        return v

    def __str__(self):
        """Return the base64 text form of the wrapped bytes."""
        encoded = base64.b64encode(self)
        return encoded.decode()

@ -181,6 +181,9 @@ class ImaginePrompt:
from imaginairy.samplers import SAMPLER_LOOKUP, SamplerName
self.prompts = self.process_prompt_input(self.prompts)
self.prompt_strength = (
7.5 if self.prompt_strength is None else self.prompt_strength
)
if self.tile_mode is True:
self.tile_mode = "xy"

@ -25,3 +25,4 @@ ignore = D104
[pylama:pylint]
generated_members=torch.*
extension-pkg-whitelist=pydantic

Loading…
Cancel
Save