# imaginAIry/imaginairy/http_app/app.py

import logging
import os.path
import sys
import traceback
from asyncio import Lock
from fastapi import FastAPI, Query, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from imaginairy.http_app.stablestudio import routes
from imaginairy.http_app.utils import generate_image
from imaginairy.schema import ImaginePrompt
logger = logging.getLogger(__name__)

# Directory holding the built StableStudio web UI, resolved relative to this
# module's own location on disk.
static_folder = f"{os.path.dirname(os.path.abspath(__file__))}/stablestudio/dist"

# Shared asyncio lock; the image endpoints below hold it around their
# generate_image calls so those generations run one at a time.
gpu_lock = Lock()

app = FastAPI()

# Allow browser calls from the listed localhost origins (presumably local
# frontend dev servers — confirm), including credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3001",
        "http://localhost:3002",
    ],
)

# StableStudio-compatible API routes are served under /api/stablestudio.
app.include_router(routes.router, prefix="/api/stablestudio")
@app.post("/api/imagine")
async def imagine_endpoint(prompt: ImaginePrompt):
    """Generate an image from a JSON ``ImaginePrompt`` body and return it.

    ``generate_image`` runs in a worker thread via ``run_in_threadpool`` so
    the event loop is not blocked. The call happens while ``gpu_lock`` is
    held, so generations started by this module's endpoints do not overlap.

    Args:
        prompt: The parsed image-generation request.

    Returns:
        A ``StreamingResponse`` over the generated image data.
    """
    async with gpu_lock:
        img_io = await run_in_threadpool(generate_image, prompt)
    # "image/jpeg" is the registered MIME type ("image/jpg" is non-standard).
    # NOTE(review): assumes generate_image yields JPEG-encoded bytes — confirm.
    return StreamingResponse(img_io, media_type="image/jpeg")
@app.get("/api/imagine")
async def imagine_get_endpoint(text: str = Query(...)):
    """Generate an image from a ``?text=`` query parameter and return it.

    This is the GET counterpart of the POST ``/api/imagine`` endpoint. It
    builds an ``ImaginePrompt`` from the required ``text`` query parameter.
    Generation runs in a worker thread while ``gpu_lock`` is held.

    Args:
        text: The prompt text (required query parameter).

    Returns:
        A ``StreamingResponse`` over the generated image data.
    """
    async with gpu_lock:
        img_io = await run_in_threadpool(generate_image, ImaginePrompt(prompt=text))
    # "image/jpeg" is the registered MIME type ("image/jpg" is non-standard).
    # NOTE(review): assumes generate_image yields JPEG-encoded bytes — confirm.
    return StreamingResponse(img_io, media_type="image/jpeg")
@app.get("/edit")
async def edit_redir():
    """Serve the web UI's ``index.html`` for the ``/edit`` path.

    The path is answered with the static entry page so the client-side app
    can handle it.
    """
    index_page = f"{static_folder}/index.html"
    return FileResponse(index_page)
@app.get("/generate")
async def generate_redir():
    """Serve the web UI's ``index.html`` for the ``/generate`` path.

    The path is answered with the static entry page so the client-side app
    can handle it.
    """
    index_page = f"{static_folder}/index.html"
    return FileResponse(index_page)
# Serve the built web UI from static_folder at the root path; html=True lets
# directory requests resolve to index.html. NOTE(review): this catch-all "/"
# mount is registered after the API/redirect routes above — keep it last so
# those routes are matched first (confirm against Starlette's routing order).
app.mount("/", StaticFiles(directory=static_folder, html=True), name="static")
@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and return a generic 500 JSON response.

    The error and its traceback are recorded through this module's
    ``logger`` rather than printed to stderr, so they follow the
    application's logging configuration. The client receives only a generic
    message.

    Args:
        request: The request that triggered the error (not used).
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status 500 and body
        ``{"message": "Internal Server Error"}``.
    """
    # Pass the exception explicitly so its traceback is logged even when this
    # handler is not running inside an active ``except`` block.
    logger.error("Unhandled error: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"message": "Internal Server Error"},
    )