You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/imaginairy/http_app/app.py

79 lines
2.0 KiB
Python

"""FastAPI application for image generation"""
import logging
import os.path
import sys
import traceback
from asyncio import Lock
from fastapi import FastAPI, Query, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from imaginairy.http_app.stablestudio import routes
from imaginairy.http_app.utils import generate_image
from imaginairy.schema import ImaginePrompt
logger = logging.getLogger(__name__)
static_folder = os.path.dirname(os.path.abspath(__file__)) + "/stablestudio/dist"
gpu_lock = Lock()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000",
"http://localhost:3001",
"http://localhost:3002",
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(routes.router, prefix="/api/stablestudio")
@app.post("/api/imagine")
async def imagine_endpoint(prompt: ImaginePrompt):
async with gpu_lock:
img_io = await run_in_threadpool(generate_image, prompt)
return StreamingResponse(img_io, media_type="image/jpg")
@app.get("/api/imagine")
async def imagine_get_endpoint(text: str = Query(...)):
async with gpu_lock:
img_io = await run_in_threadpool(generate_image, ImaginePrompt(prompt=text))
return StreamingResponse(img_io, media_type="image/jpg")
@app.get("/edit")
async def edit_redir():
return FileResponse(f"{static_folder}/index.html")
@app.get("/generate")
async def generate_redir():
return FileResponse(f"{static_folder}/index.html")
app.mount("/", StaticFiles(directory=static_folder, html=True), name="static")
@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
print(f"Unhandled error: {exc}", file=sys.stderr)
traceback.print_exc(file=sys.stderr)
return JSONResponse(
status_code=500,
content={"message": "Internal Server Error"},
)