mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
"""Manifest as an app service."""
|
|
|
|
from typing import Any, Dict, cast
|
|
|
|
from fastapi import APIRouter, FastAPI, HTTPException
|
|
|
|
from manifest import Manifest
|
|
from manifest.response import Response as ManifestResponse
|
|
from web_app import schemas
|
|
|
|
# Application object; the root route below is registered directly on it.
app = FastAPI()
|
|
# Router that collects the API endpoints; it is attached to `app` at the
# bottom of this module via `app.include_router`.
api_router = APIRouter()
|
|
|
|
|
|
@app.get("/")
async def root() -> Dict:
    """Serve the landing route with a fixed greeting payload."""
    greeting: Dict = {"message": "Hello to the Manifest App"}
    return greeting
|
|
|
|
|
|
@api_router.post("/prompt/", status_code=201, response_model=schemas.ManifestResponse)
def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict:
    """Build a Manifest session from the request body and run its prompt.

    Args:
        manifest_in: request body carrying the client/cache settings used to
            construct the ``Manifest`` plus the prompt and its sampling
            parameters (``n``, ``max_tokens`` and the optional
            ``temperature``, ``top_k``, ``top_p``).

    Returns:
        A dict with the keys ``response``, ``cached`` and ``request_params``,
        filled from the response object returned by ``manifest.run``.

    Raises:
        HTTPException: with status 500 and the stringified error as detail
            when ``manifest.run`` raises.
    """
    manifest = Manifest(
        client_name=manifest_in.client_name,
        client_connection=manifest_in.client_connection,
        engine=manifest_in.engine,
        cache_name=manifest_in.cache_name,
        cache_connection=manifest_in.cache_connection,
    )
    manifest_prompt_args: Dict[str, Any] = {
        "n": manifest_in.n,
        "max_tokens": manifest_in.max_tokens,
    }
    # Forward the optional sampling parameters only when the caller set them.
    # Compare against None instead of relying on truthiness so that an
    # explicit 0 / 0.0 (e.g. temperature=0) is forwarded rather than dropped.
    # NOTE(review): assumes schemas.ManifestCreate uses None to mean
    # "not provided" for these fields -- confirm against the schema.
    if manifest_in.temperature is not None:
        manifest_prompt_args["temperature"] = manifest_in.temperature
    if manifest_in.top_k is not None:
        manifest_prompt_args["top_k"] = manifest_in.top_k
    if manifest_in.top_p is not None:
        manifest_prompt_args["top_p"] = manifest_in.top_p

    try:
        response = manifest.run(
            prompt=manifest_in.prompt, return_response=True, **manifest_prompt_args
        )
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500 and keep the
        # original exception chained as the cause for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # return_response=True is passed above, so treat the result as the
    # response object; cast() only informs the type checker.
    response = cast(ManifestResponse, response)
    return {
        "response": response.get_response(),
        "cached": response.is_cached(),
        "request_params": response.get_request_obj(),
    }
|
|
|
|
|
|
# Attach the routes declared on api_router to the application.
app.include_router(api_router)
|