mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
316114e660
Wrote an openai script and custom prompt to generate them.
79 lines
2.0 KiB
Python
79 lines
2.0 KiB
Python
"""FastAPI application for image generation"""
|
|
|
|
import logging
|
|
import os.path
|
|
import sys
|
|
import traceback
|
|
from asyncio import Lock
|
|
|
|
from fastapi import FastAPI, Query, Request
|
|
from fastapi.concurrency import run_in_threadpool
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from imaginairy.http_app.stablestudio import routes
|
|
from imaginairy.http_app.utils import generate_image
|
|
from imaginairy.schema import ImaginePrompt
|
|
|
|
# Logger for this module, named after the module's import path.
logger = logging.getLogger(__name__)

# Absolute path of the built StableStudio front-end bundle, located in
# ``stablestudio/dist`` relative to the directory containing this file.
static_folder = "/".join(
    (os.path.dirname(os.path.abspath(__file__)), "stablestudio/dist")
)

# asyncio lock that the image-generation endpoints hold while generating, so
# that at most one generation is in progress at a time.
gpu_lock = Lock()
|
|
|
|
|
|
# The ASGI application exposed by this module.
app = FastAPI()

# Allow cross-origin browser requests (including credentials) from
# http://localhost on ports 3000, 3001 and 3002 -- presumably local front-end
# dev servers. Any method and any header is accepted from those origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"http://localhost:{port}" for port in (3000, 3001, 3002)],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# StableStudio API routes live under the /api/stablestudio prefix.
app.include_router(routes.router, prefix="/api/stablestudio")
|
|
|
|
|
|
@app.post("/api/imagine")
async def imagine_endpoint(prompt: ImaginePrompt):
    """Generate an image from a JSON ``ImaginePrompt`` body and stream it back.

    ``generate_image`` runs in a worker thread (via ``run_in_threadpool``) so
    the event loop is not blocked. ``gpu_lock`` is held for the duration of
    the generation, serializing it with every other caller of that lock.

    Returns:
        A ``StreamingResponse`` over the buffer returned by ``generate_image``.
    """
    async with gpu_lock:
        img_io = await run_in_threadpool(generate_image, prompt)
    # "image/jpeg" is the registered MIME type; the previous "image/jpg" is
    # non-standard and may not be recognized by clients.
    # NOTE(review): assumes generate_image encodes JPEG -- confirm in
    # imaginairy.http_app.utils.
    return StreamingResponse(img_io, media_type="image/jpeg")
|
|
|
|
|
|
@app.get("/api/imagine")
async def imagine_get_endpoint(text: str = Query(...)):
    """Generate an image from the required ``text`` query parameter.

    The text is wrapped in ``ImaginePrompt(prompt=text)`` with every other
    setting left at its default. Generation runs in a worker thread while
    ``gpu_lock`` is held, serializing it with every other caller of that lock.

    Returns:
        A ``StreamingResponse`` over the buffer returned by ``generate_image``.
    """
    async with gpu_lock:
        img_io = await run_in_threadpool(generate_image, ImaginePrompt(prompt=text))
    # "image/jpeg" is the registered MIME type; "image/jpg" is non-standard.
    # NOTE(review): assumes generate_image encodes JPEG -- confirm in
    # imaginairy.http_app.utils.
    return StreamingResponse(img_io, media_type="image/jpeg")
|
|
|
|
|
|
@app.get("/edit")
async def edit_redir():
    """Serve the front-end ``index.html`` for requests to ``/edit``.

    The file is returned directly rather than via an HTTP redirect,
    presumably so the single-page app can resolve this path client-side.
    """
    index_path = static_folder + "/index.html"
    return FileResponse(index_path)
|
|
|
|
|
|
@app.get("/generate")
async def generate_redir():
    """Serve the front-end ``index.html`` for requests to ``/generate``.

    The file is returned directly rather than via an HTTP redirect,
    presumably so the single-page app can resolve this path client-side.
    """
    index_path = static_folder + "/index.html"
    return FileResponse(index_path)
|
|
|
|
|
|
# Serve the built front-end from the site root. With ``html=True``,
# StaticFiles serves ``index.html`` for directory requests. Routes are matched
# in registration order, so the explicit routes registered earlier take
# precedence over this catch-all mount.
app.mount(
    "/",
    StaticFiles(html=True, directory=static_folder),
    name="static",
)
|
|
|
|
|
|
@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
    """Report an unhandled exception on stderr and return a generic 500 reply.

    The traceback is printed from ``exc`` itself, not from ``sys.exc_info()``
    (which is what ``traceback.print_exc`` uses). That way the output is
    correct even if the framework calls this handler outside an active
    ``except`` block. The response body carries only a generic message.

    Args:
        request: The request whose handling raised (unused).
        exc: The exception that escaped the route handler.

    Returns:
        A ``JSONResponse`` with status 500.
    """
    print(f"Unhandled error: {exc}", file=sys.stderr)
    traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
    return JSONResponse(
        status_code=500,
        content={"message": "Internal Server Error"},
    )
|