You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
EVAL/api/main.py

129 lines
3.3 KiB
Python

from typing import Dict, List, TypedDict
import re
import uvicorn
import torch
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from env import settings
from core.prompts.error import ERROR_PROMPT
from core.agents.manager import AgentManager
from core.tools.base import BaseToolSet
from core.tools.terminal import Terminal
from core.tools.editor import CodeEditor
from core.tools.cpu import (
RequestsGet,
WineDB,
ExitConversation,
)
from core.tools.gpu import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.image import ImageCaptioning
from core.handlers.dataframe import CsvToDataframe
from core.upload import StaticUploader
from logger import logger
app = FastAPI()
app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static")
uploader = StaticUploader.from_settings(settings)
use_gpu = settings["USE_GPU"] and torch.cuda.is_available()
toolsets: List[BaseToolSet] = [
Terminal(),
CodeEditor(),
RequestsGet(),
ExitConversation(),
]
if use_gpu:
toolsets.extend(
[
Text2Image("cuda"),
ImageEditing("cuda"),
InstructPix2Pix("cuda"),
VisualQuestionAnswering("cuda"),
]
)
handlers: Dict[FileType, BaseHandler] = {}
handlers[FileType.DATAFRAME] = CsvToDataframe()
if use_gpu:
handlers[FileType.IMAGE] = ImageCaptioning("cuda")
if settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]:
toolsets.append(WineDB())
agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers)
class Request(BaseModel):
key: str
query: str
files: List[str]
class Response(TypedDict):
response: str
files: List[str]
@app.get("/")
async def index():
return {"message": f"Hello World. I'm {settings['BOT_NAME']}."}
@app.post("/command")
async def command(request: Request) -> Response:
query = request.query
files = request.files
session = request.key
logger.info("=============== Running =============")
logger.info(f"Query: {query}, Files: {files}")
executor = agent_manager.get_or_create_executor(session)
logger.info(f"======> Previous memory:\n\t{executor.memory}")
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
promptedQuery += query
logger.info(f"======> Prompted Text:\n\t{promptedQuery}")
try:
res = executor({"input": promptedQuery})
except Exception as e:
logger.error(f"error while processing request: {str(e)}")
try:
res = executor(
{
"input": ERROR_PROMPT.format(promptedQuery=promptedQuery, e=str(e)),
}
)
except Exception as e:
return {"response": str(e), "files": []}
images = re.findall("(image/\S*png)", res["output"])
dataframes = re.findall("(dataframe/\S*csv)", res["output"])
return {
"response": res["output"],
"files": [uploader.upload(image) for image in images]
+ [uploader.upload(dataframe) for dataframe in dataframes],
}
def serve():
uvicorn.run("api.main:app", host="0.0.0.0", port=settings["PORT"])