EVAL/api/main.py

129 lines
3.3 KiB
Python
Raw Normal View History

2023-03-18 12:42:14 +00:00
from typing import Dict, List, TypedDict
2023-03-17 15:55:15 +00:00
import re
import uvicorn
2023-03-17 15:55:15 +00:00
import torch
2023-03-17 15:55:15 +00:00
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
2023-03-17 15:55:15 +00:00
from pydantic import BaseModel
2023-03-20 08:27:20 +00:00
2023-03-18 12:42:14 +00:00
from env import settings
2023-03-17 15:55:15 +00:00
from core.prompts.error import ERROR_PROMPT
from core.agents.manager import AgentManager
from core.tools.base import BaseToolSet
from core.tools.terminal import Terminal
from core.tools.editor import CodeEditor
from core.tools.cpu import (
2023-03-18 12:42:14 +00:00
RequestsGet,
WineDB,
2023-03-22 03:21:29 +00:00
ExitConversation,
2023-03-18 12:42:14 +00:00
)
from core.tools.gpu import (
2023-03-18 12:42:14 +00:00
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.image import ImageCaptioning
from core.handlers.dataframe import CsvToDataframe
from core.upload import StaticUploader
2023-03-22 00:34:52 +00:00
from logger import logger
2023-03-17 15:55:15 +00:00
app = FastAPI()
# Expose everything under StaticUploader.STATIC_DIR at the /static URL prefix.
app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static")

uploader = StaticUploader.from_settings(settings)

# GPU-backed tools/handlers are enabled only if the USE_GPU setting is truthy
# AND PyTorch reports an available CUDA device.
use_gpu = settings["USE_GPU"] and torch.cuda.is_available()

# CPU tools are always available; GPU tools are appended after them, in this order.
toolsets: List[BaseToolSet] = [Terminal(), CodeEditor(), RequestsGet(), ExitConversation()]
if use_gpu:
    toolsets += [
        tool_cls("cuda")
        for tool_cls in (Text2Image, ImageEditing, InstructPix2Pix, VisualQuestionAnswering)
    ]

# File-type -> preprocessor used by FileHandler; image captioning needs the GPU.
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
if use_gpu:
    handlers[FileType.IMAGE] = ImageCaptioning("cuda")

# Register the WineDB tool only when both its host and password settings are set.
if settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]:
    toolsets.append(WineDB())

agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers)
2023-03-17 15:55:15 +00:00
# Body of a POST /command call. Documented with comments rather than a
# docstring so the generated OpenAPI schema stays unchanged.
class Request(BaseModel):
    # Session id; used to look up (or create) this conversation's agent executor.
    key: str
    # The user's instruction text; appended after any file-derived prompt text.
    query: str
    # File references, each passed to file_handler.handle() to build prompt context.
    files: List[str]
2023-03-17 15:55:15 +00:00
class Response(TypedDict):
    """Shape of the JSON object returned by the POST /command endpoint."""

    # The agent's answer text, or the error message if both attempts failed.
    response: str
    # Results of uploader.upload() for image/dataframe paths found in the answer.
    files: List[str]
2023-03-17 15:55:15 +00:00
@app.get("/")
async def index():
    """Root endpoint: greet the caller using the configured BOT_NAME setting."""
    bot_name = settings["BOT_NAME"]
    greeting = f"Hello World. I'm {bot_name}."
    return {"message": greeting}
2023-03-17 15:55:15 +00:00
@app.post("/command")
async def command(request: Request) -> Response:
    """Run one user turn through the session's agent and return its answer.

    Each entry of ``request.files`` is converted to prompt text by
    ``file_handler.handle``. Those snippets are joined with newlines and the
    user's query is appended. The combined prompt is sent to the executor
    that ``agent_manager`` associates with ``request.key``.

    If the executor raises, it is invoked once more with ``ERROR_PROMPT``
    filled in with the prompt and the error text. If that retry also raises,
    the retry's error message is returned as the response with no files.

    Substrings of the answer matching ``image/...png`` and
    ``dataframe/...csv`` are passed to ``uploader.upload``. The results are
    returned in ``files``, images first, then dataframes.
    """
    query = request.query
    files = request.files
    session = request.key

    logger.info("=============== Running =============")
    logger.info(f"Query: {query}, Files: {files}")
    executor = agent_manager.get_or_create_executor(session)

    logger.info(f"======> Previous memory:\n\t{executor.memory}")

    # NOTE(review): no separator is added between the last file snippet and
    # the query; this presumably relies on handler prompts ending in a
    # newline — confirm.
    prompted_query = "\n".join([file_handler.handle(file) for file in files])
    prompted_query += query

    logger.info(f"======> Prompted Text:\n\t{prompted_query}")

    # NOTE(review): if executor(...) blocks, calling it directly inside an
    # `async def` endpoint stalls the event loop for all requests — consider
    # a threadpool; confirm the executor's thread-safety first.
    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        logger.error(f"error while processing request: {str(e)}")
        # Give the agent one chance to recover by showing it the failure.
        try:
            res = executor(
                {
                    "input": ERROR_PROMPT.format(promptedQuery=prompted_query, e=str(e)),
                }
            )
        except Exception as retry_error:
            return {"response": str(retry_error), "files": []}

    output = res["output"]
    # Raw strings: "\S" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12). Pattern unchanged.
    images = re.findall(r"(image/\S*png)", output)
    dataframes = re.findall(r"(dataframe/\S*csv)", output)

    return {
        "response": output,
        "files": [uploader.upload(image) for image in images]
        + [uploader.upload(dataframe) for dataframe in dataframes],
    }
def serve():
    """Launch Uvicorn on all interfaces, serving ``api.main:app`` on the PORT setting."""
    port = settings["PORT"]
    uvicorn.run("api.main:app", host="0.0.0.0", port=port)