You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
manifest/web_app/main.py

57 lines
1.7 KiB
Python

"""Manifest as an app service."""
from typing import Any, Dict, cast
from fastapi import APIRouter, FastAPI, HTTPException
from manifest import Manifest
from manifest.response import Response as ManifestResponse
from web_app import schemas
# FastAPI application instance for the Manifest web service.
app = FastAPI()
# Router holding the API endpoints; it is attached to ``app`` via
# ``app.include_router`` at the end of the module.
api_router = APIRouter()
@app.get("/")
async def root() -> Dict:
    """Greeting endpoint served at the application root.

    Returns:
        A dict with a single ``message`` key holding a greeting string.
    """
    greeting = "Hello to the Manifest App"
    return dict(message=greeting)
@api_router.post("/prompt/", status_code=201, response_model=schemas.ManifestResponse)
def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict:
    """Prompt a manifest session and query.

    Builds a ``Manifest`` from the client/cache settings carried by
    ``manifest_in``, runs ``manifest_in.prompt`` through it and returns the
    response along with its cache flag and request parameters.

    Args:
        manifest_in: Request body with the prompt, sampling parameters
            (``n``, ``max_tokens``, optional ``temperature``/``top_k``/
            ``top_p``) and the client/cache configuration.

    Returns:
        Dict with keys ``response`` (``Response.get_response()``),
        ``cached`` (``Response.is_cached()``) and ``request_params``
        (``Response.get_request_obj()``).

    Raises:
        HTTPException: status 500 with the error text if running the
            prompt raises.
    """
    manifest = Manifest(
        client_name=manifest_in.client_name,
        client_connection=manifest_in.client_connection,
        engine=manifest_in.engine,
        cache_name=manifest_in.cache_name,
        cache_connection=manifest_in.cache_connection,
    )

    manifest_prompt_args: Dict[str, Any] = {
        "n": manifest_in.n,
        "max_tokens": manifest_in.max_tokens,
    }
    # Forward only the optional sampling params the caller actually set.
    # Compare against None (not truthiness) so explicit zeros such as
    # temperature=0 are passed through instead of being silently dropped.
    # NOTE(review): assumes unset fields default to None in
    # schemas.ManifestCreate — confirm against the schema.
    for param_name in ("temperature", "top_k", "top_p"):
        param_value = getattr(manifest_in, param_name)
        if param_value is not None:
            manifest_prompt_args[param_name] = param_value

    try:
        raw_response = manifest.run(
            prompt=manifest_in.prompt, return_response=True, **manifest_prompt_args
        )
    except Exception as e:
        # Surface any failure as a 500; keep the original cause chained for logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # return_response=True makes run() hand back a Response object.
    response = cast(ManifestResponse, raw_response)
    return {
        "response": response.get_response(),
        "cached": response.is_cached(),
        "request_params": response.get_request_obj(),
    }
# Register the routes declared on ``api_router`` (e.g. ``/prompt/``) with the
# app. This runs at the end of the module, after every route is declared, so
# all of them are included.
app.include_router(api_router)