You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
EVAL/api/main.py

111 lines
3.0 KiB
Python

1 year ago
import re
from typing import Dict, List, TypedDict
1 year ago
import uvicorn
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.prompts.error import ERROR_PROMPT
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from logger import logger
from pydantic import BaseModel
1 year ago
# FastAPI application; files under the uploader's static directory are served at /static.
app = FastAPI()
app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static")

uploader = StaticUploader.from_settings(settings)

# Tools given to every agent, independent of hardware.
toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]
# Maps each supported upload type to the handler that converts it into prompt text.
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}

if settings["USE_GPU"]:
    # GPU-only dependencies are imported here, not at the top of the file, so a
    # deployment with USE_GPU disabled never imports torch or the GPU tool modules.
    import torch

    from core.handlers.image import ImageCaptioning
    from core.tools.gpu import (
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        # Image uploads can only be handled when the captioning model has a GPU.
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")
    else:
        # Previously this fell through silently, leaving the service without any
        # image tools even though USE_GPU was requested.
        logger.warning(
            "USE_GPU is enabled but CUDA is not available; "
            "GPU tools and image file handling are disabled."
        )

agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers)
1 year ago
class Request(BaseModel):
    """Body of a ``POST /command`` call."""

    # Session key; the /command handler passes it to
    # agent_manager.get_or_create_executor to pick the conversation's executor.
    key: str
    # The user's message, appended after any file prompts.
    query: str
    # Attached files; each entry is passed to file_handler.handle.
    # NOTE(review): presumably paths or URLs — confirm against FileHandler.
    files: List[str]
1 year ago
class Response(TypedDict):
    """JSON payload returned by ``POST /command``."""

    # The agent's output text, or an error message if the agent failed twice.
    response: str
    # Results of uploader.upload for image/dataframe paths found in the output.
    files: List[str]
1 year ago
@app.get("/")
async def index():
    """Root endpoint: return a greeting that includes the configured ``BOT_NAME``."""
    return {"message": f"Hello World. I'm {settings['BOT_NAME']}."}
1 year ago
@app.post("/command")
async def command(request: Request) -> Response:
    """Run a user query through the session's agent executor.

    Each attached file is converted to prompt text by ``file_handler`` and
    placed before the query. If the executor raises, the error is logged and
    the executor is retried once with ``ERROR_PROMPT`` (formatted with the
    original prompt and the error text). If the retry also raises, that
    exception's message is returned as the response with no files.

    Paths of the form ``image/...png`` and ``dataframe/...csv`` found in the
    agent's output are passed to ``uploader.upload``, and the results are
    returned in ``files``.
    """
    query = request.query
    files = request.files
    session = request.key

    executor = agent_manager.get_or_create_executor(session)

    # NOTE(review): file prompts are joined with "\n", but the query is appended
    # with no separator. Presumably each handler's text ends with a newline — confirm.
    prompted_query = "\n".join([file_handler.handle(file) for file in files])
    prompted_query += query

    # NOTE(review): executor(...) is called synchronously (no await) inside an
    # async route, so it runs on the event-loop thread for the whole agent run.
    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        logger.error(f"error while processing request: {str(e)}")
        # Retry once, giving the agent the failed prompt and the error text.
        try:
            res = executor(
                {
                    "input": ERROR_PROMPT.format(promptedQuery=prompted_query, e=str(e)),
                }
            )
        except Exception as retry_error:
            return {"response": str(retry_error), "files": []}

    output = res["output"]
    # Raw strings: "\S" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    images = re.findall(r"(image/\S*png)", output)
    dataframes = re.findall(r"(dataframe/\S*csv)", output)

    return {
        "response": output,
        "files": [uploader.upload(image) for image in images]
        + [uploader.upload(dataframe) for dataframe in dataframes],
    }
def serve():
    """Run ``api.main:app`` with uvicorn on all interfaces at ``settings["PORT"]``."""
    uvicorn.run("api.main:app", host="0.0.0.0", port=settings["PORT"])