gpt4all/gpt4all-api/gpt4all_api/app/main.py
Andriy Mulyar 633e2a2137
GPT4All API Scaffolding. Matches OpenAI OpenAPI spec for chats and completions (#839)
* GPT4All API Scaffolding. Matches OpenAI OpenAPI spec for engines, chats and completions

* Edits for docker building

* FastAPI app builds and pydantic models are accurate

* Added groovy download into dockerfile

* improved dockerfile

* Chat completions endpoint edits

* API unit test sketch

* Working example of groovy inference with open ai api

* Added lines to test

* Set default to mpt
2023-06-28 14:28:52 -04:00

61 lines
2.0 KiB
Python

import os
import docs
import logging
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.cors import CORSMiddleware
from fastapi.logger import logger as fastapi_logger
from api_v1.settings import settings
from api_v1.api import router as v1_router
from api_v1 import events
import os
# Module-level logger for this entry point.
logger = logging.getLogger(__name__)

# The FastAPI application object; its description text comes from the docs module.
app = FastAPI(description=docs.desc, title="GPT4All API")

# CORS configuration (in case you want to deploy).
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=True,
    allow_origins=["*"],
)

# NOTE(review): this INFO record is logged before the logging setup at the
# bottom of this module has run, so it may not be emitted -- confirm.
logger.info("Adding v1 endpoints..")

# Mount the v1 router under the /v1 prefix.
app.include_router(v1_router, prefix="/v1")

# Register the startup hook built by api_v1.events and route HTTPException
# instances to the shared error handler.
app.add_event_handler("startup", events.startup_event_handler(app))
app.add_exception_handler(HTTPException, events.on_http_error)
@app.on_event("startup")
async def startup():
    """Create the GPT4All model when the application starts.

    Reads the model name and model directory from ``settings``, logs the
    joined path, builds a ``GPT4All`` instance from them and stores it in the
    module-level ``model`` global. Per the log message, constructing the
    model may download or fetch the model file.
    """
    global model

    model_dir = settings.gpt4all_path
    model_name = settings.model
    model_file = os.path.join(model_dir, model_name)
    logger.info(f"Downloading/fetching model: {model_file}")

    # Imported inside the handler (not at module import time), as before.
    from gpt4all import GPT4All

    model = GPT4All(model_path=model_dir, model_name=model_name)
    logger.info("GPT4All API is ready.")
@app.on_event("shutdown")
async def shutdown():
    """Log that the API is shutting down; nothing else is done here."""
    logger.info("Shutting down API")
# Logging setup: this is needed to get logs to show up in the app.
# The SERVER_SOFTWARE environment variable decides which branch runs.
if "gunicorn" not in os.environ.get("SERVER_SOFTWARE", ""):
    # Not under gunicorn: install a basic root configuration at INFO level.
    # https://github.com/tiangolo/fastapi/issues/2019
    LOG_FORMAT2 = "[%(asctime)s %(process)d:%(threadName)s] %(name)s - %(levelname)s - %(message)s | %(filename)s:%(lineno)d"
    logging.basicConfig(format=LOG_FORMAT2, level=logging.INFO)
else:
    # Under gunicorn: reuse gunicorn's error-log handlers and logger level.
    gunicorn_logger = logging.getLogger("gunicorn")
    gunicorn_error_logger = logging.getLogger("gunicorn.error")
    root_logger = logging.getLogger()
    uvicorn_logger = logging.getLogger("uvicorn.access")

    # Send FastAPI and uvicorn access records to gunicorn's error handlers.
    fastapi_logger.handlers = gunicorn_error_logger.handlers
    uvicorn_logger.handlers = gunicorn_error_logger.handlers

    # Align the FastAPI and root logger levels with the "gunicorn" logger.
    fastapi_logger.setLevel(gunicorn_logger.level)
    root_logger.setLevel(gunicorn_logger.level)