imaginAIry/imaginairy/http_app/app.py
Bryce 316114e660 docs: add docstrings
Wrote an openai script and custom prompt to generate them.
2023-12-15 14:32:01 -08:00

79 lines
2.0 KiB
Python

"""FastAPI application for image generation"""
import logging
import os.path
import sys
import traceback
from asyncio import Lock
from fastapi import FastAPI, Query, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from imaginairy.http_app.stablestudio import routes
from imaginairy.http_app.utils import generate_image
from imaginairy.schema import ImaginePrompt
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Absolute path of the front-end bundle located next to this file.
static_folder = f"{os.path.dirname(os.path.abspath(__file__))}/stablestudio/dist"

# Shared asyncio lock; the imagine endpoints hold it while generating an image
# so that only one generation runs at a time.
gpu_lock = Lock()

app = FastAPI()

# CORS: accept browser requests coming from localhost ports 3000-3002.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"http://localhost:{port}" for port in (3000, 3001, 3002)],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the stablestudio router under the /api/stablestudio prefix.
app.include_router(routes.router, prefix="/api/stablestudio")
@app.post("/api/imagine")
async def imagine_endpoint(prompt: ImaginePrompt):
    """Generate an image for the posted prompt and stream it back.

    Args:
        prompt: The ImaginePrompt describing the image to generate.

    Returns:
        A StreamingResponse wrapping the object returned by generate_image,
        labelled with the JPEG media type.
    """
    # Hold the lock only for the generation itself; the blocking call is
    # handed to the threadpool via run_in_threadpool and awaited.
    async with gpu_lock:
        img_io = await run_in_threadpool(generate_image, prompt)
    # "image/jpeg" is the registered media type; "image/jpg" is not.
    # NOTE(review): assumes generate_image emits JPEG data, as the previous
    # "image/jpg" label implied -- confirm against generate_image.
    return StreamingResponse(img_io, media_type="image/jpeg")
@app.get("/api/imagine")
async def imagine_get_endpoint(text: str = Query(...)):
    """Generate an image from a plain-text prompt given as a query parameter.

    Args:
        text: Required query parameter; used as the prompt text of a new
            ImaginePrompt.

    Returns:
        A StreamingResponse wrapping the object returned by generate_image,
        labelled with the JPEG media type.
    """
    prompt = ImaginePrompt(prompt=text)
    # Hold the lock only for the generation itself; the blocking call is
    # handed to the threadpool via run_in_threadpool and awaited.
    async with gpu_lock:
        img_io = await run_in_threadpool(generate_image, prompt)
    # "image/jpeg" is the registered media type; "image/jpg" is not.
    # NOTE(review): assumes generate_image emits JPEG data, as the previous
    # "image/jpg" label implied -- confirm against generate_image.
    return StreamingResponse(img_io, media_type="image/jpeg")
@app.get("/edit")
async def edit_redir():
    """Answer /edit with the index.html file from the static folder."""
    index_path = f"{static_folder}/index.html"
    return FileResponse(index_path)
@app.get("/generate")
async def generate_redir():
    """Answer /generate with the index.html file from the static folder."""
    index_path = f"{static_folder}/index.html"
    return FileResponse(index_path)
# Serve the front-end bundle for every remaining path.
# NOTE(review): mounted after the routes above, presumably so the catch-all
# "/" mount does not shadow them -- confirm against Starlette's route ordering.
app.mount(
    "/",
    StaticFiles(directory=static_folder, html=True),
    name="static",
)
@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
    """Report an unhandled exception on stderr and answer with a generic 500.

    Args:
        request: The request being processed when the error occurred (unused).
        exc: The exception passed to this handler.

    Returns:
        A JSONResponse with status code 500 and a fixed message that carries
        no details of the error.
    """
    print(f"Unhandled error: {exc}", file=sys.stderr)
    # Format the traceback from `exc` itself instead of the interpreter's
    # "currently handled exception" state that print_exc() relies on; that
    # state is empty (printing "NoneType: None") whenever this coroutine is
    # not running inside an active except block.
    traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
    return JSONResponse(
        status_code=500,
        content={"message": "Internal Server Error"},
    )