You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
EVAL/api/main.py

117 lines
3.0 KiB
Python

import re
from typing import Dict, List, TypedDict
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.handlers.image import ImageCaptioning
from core.prompts.error import ERROR_PROMPT
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet, WineDB
from core.tools.editor import CodeEditor
from core.tools.gpu import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
from logger import logger
# --- Application wiring ------------------------------------------------------
# Everything below executes once, when this module is imported.

app = FastAPI()
# Expose StaticUploader.STATIC_DIR over HTTP under the /static prefix.
app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static")
uploader = StaticUploader.from_settings(settings)

# GPU-backed tools/handlers are enabled only when the USE_GPU setting is
# truthy AND torch reports an available CUDA device.
use_gpu = settings["USE_GPU"] and torch.cuda.is_available()

# These tools are always registered, regardless of GPU availability.
toolsets: List[BaseToolSet] = [Terminal(), CodeEditor(), RequestsGet(), ExitConversation()]
if use_gpu:
    toolsets += [
        Text2Image("cuda"),
        ImageEditing("cuda"),
        InstructPix2Pix("cuda"),
        VisualQuestionAnswering("cuda"),
    ]

# DATAFRAME files always get CsvToDataframe; IMAGE files get ImageCaptioning
# only when the GPU path is enabled.
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
if use_gpu:
    handlers[FileType.IMAGE] = ImageCaptioning("cuda")

# WineDB is registered only when both WINEDB_HOST and WINEDB_PASSWORD are set
# to truthy values.
if settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]:
    toolsets.append(WineDB())

agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers)
class Request(BaseModel):
    """Request body accepted by the ``POST /command`` endpoint."""

    # Passed to agent_manager.get_or_create_executor() to select (or create)
    # the executor for this conversation.
    key: str
    # The user's text for this turn; appended after any file prompts.
    query: str
    # File references; each one is passed to file_handler.handle() and the
    # resulting text is placed in front of the query.
    files: List[str]
class Response(TypedDict):
    """Payload shape returned by the ``POST /command`` endpoint."""

    # The agent's text output, or the error message if the retry also failed.
    response: str
    # Values returned by uploader.upload() for files mentioned in the output.
    files: List[str]
@app.get("/")
async def index():
    """Root endpoint: greet the caller using the configured ``BOT_NAME``."""
    bot_name = settings['BOT_NAME']
    return {"message": f"Hello World. I'm {bot_name}."}
@app.post("/command")
async def command(request: Request) -> Response:
    """Run one user turn through the session's agent and collect output files.

    The text produced by ``file_handler.handle`` for each attached file is
    prepended to the user's query. The combined prompt is then sent to the
    executor obtained from ``agent_manager`` for ``request.key``.

    If that call raises, the error is logged. The executor is then invoked
    once more with ``ERROR_PROMPT`` describing the failure. If the retry also
    raises, the retry's error message is returned as the response with no
    files.

    Every ``image/...png`` and ``dataframe/...csv`` path found in the agent's
    output is passed to ``uploader.upload``. The returned values are listed in
    ``files``, images first and then dataframes.

    Args:
        request: Executor key, user query, and attached file references.

    Returns:
        A ``Response`` dict holding the agent's text output and uploaded files.
    """
    query = request.query
    files = request.files
    session = request.key

    executor = agent_manager.get_or_create_executor(session)

    # File prompts are newline-joined, and the query is appended with no extra
    # separator after the last one.
    # NOTE(review): this presumably relies on each handler's output ending in
    # its own separator; confirm against the prompt templates.
    prompted_query = "\n".join([file_handler.handle(file) for file in files])
    prompted_query += query

    # NOTE(review): executor(...) is a synchronous call inside an async
    # endpoint, so it likely blocks the event loop for the whole agent run.
    # Consider offloading it to a thread pool once executor thread-safety is
    # confirmed.
    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        logger.error(f"error while processing request: {str(e)}")
        try:
            # One recovery attempt: tell the agent what failed.
            # The template placeholder is named `promptedQuery`.
            res = executor(
                {
                    "input": ERROR_PROMPT.format(
                        promptedQuery=prompted_query, e=str(e)
                    ),
                }
            )
        except Exception as retry_error:
            return {"response": str(retry_error), "files": []}

    output = res["output"]
    # Raw strings: "\S" inside a plain string literal is an invalid escape
    # sequence (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    images = re.findall(r"(image/\S*png)", output)
    dataframes = re.findall(r"(dataframe/\S*csv)", output)

    return {
        "response": output,
        "files": [uploader.upload(image) for image in images]
        + [uploader.upload(dataframe) for dataframe in dataframes],
    }
def serve():
    """Run uvicorn on all interfaces, serving the ``api.main:app`` import string.

    The listening port is taken from ``settings["PORT"]``.
    """
    port = settings["PORT"]
    uvicorn.run("api.main:app", host="0.0.0.0", port=port)