You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
EVAL/api/main.py

104 lines
2.6 KiB
Python

import re
from typing import Dict, List, TypedDict
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from ansi import ANSI, Color, Style, dim_multiline
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.prompts.error import ERROR_PROMPT
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
from logger import logger
# FastAPI application exposed by this module.
app = FastAPI()
# Serve everything under StaticUploader.STATIC_DIR at the "/static" URL prefix.
app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static")
# Uploader built from env settings; the /command endpoint calls its `upload`
# on each file path it finds in the agent's output.
uploader = StaticUploader.from_settings(settings)
# Tool sets available on every deployment. GPU-only tools may be appended to
# this list further down when USE_GPU is enabled and CUDA is present.
toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]

# File-type -> handler registry passed to FileHandler; more entries (e.g. images)
# may be added further down.
handlers: Dict[FileType, BaseHandler] = {}
handlers[FileType.DATAFRAME] = CsvToDataframe()
# Optional GPU tooling. torch and the GPU tool/handler classes are imported only
# when the USE_GPU setting is truthy.
if settings["USE_GPU"]:
    import torch

    from core.handlers.image import ImageCaptioning
    from core.tools.gpu import (
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    # Register the GPU tools and the image handler only when torch reports a CUDA
    # device; every one of them is constructed with the "cuda" device string.
    # NOTE(review): if USE_GPU is set but CUDA is unavailable, nothing is
    # registered and no warning is emitted - consider logging this case.
    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        # Image uploads are routed to ImageCaptioning (presumably produces a
        # caption used as prompt text - confirm in core.handlers.image).
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")
# Agent manager built from every tool set registered above (GPU tools included
# when they were enabled); /command asks it for one executor per session key.
agent_manager = AgentManager.create(toolsets=toolsets)
# Dispatches each uploaded file reference to the handler registered for its
# type; /command joins the results into the agent prompt.
file_handler = FileHandler(handlers=handlers)
# Request body of POST /command.
# (Comments rather than a docstring: pydantic would publish a class docstring as
# the schema description.)
class Request(BaseModel):
    # Session identifier; used to look up (or create) this conversation's executor.
    key: str
    # The user's query text; it is placed after any file-derived prompt text.
    query: str
    # File references, each passed to file_handler.handle (presumably URLs -
    # confirm with the client).
    files: List[str]
# Payload returned by POST /command:
#   response - the agent's text output (or an error message if the agent raised)
#   files    - results of uploader.upload for each artifact path found in output
# Functional TypedDict form: same name, keys, annotations and totality as a
# class-based declaration, and no docstring leaks into the OpenAPI schema.
Response = TypedDict("Response", {"response": str, "files": List[str]})
@app.get("/")
async def index():
    # Root greeting endpoint: reports the configured bot name.
    # (Comment, not docstring: FastAPI publishes endpoint docstrings in OpenAPI.)
    bot_name = settings["BOT_NAME"]
    greeting = f"Hello World. I'm {bot_name}."
    return {"message": greeting}
# Artifact paths the agent may mention in its output. The alternation has no
# capture groups, so re.findall yields whole matched strings rather than
# (image, dataframe) group tuples; raw string so "\S" is a regex class, not an
# invalid string escape.
_OUTPUT_FILE_PATTERN = re.compile(r"image/\S*png|dataframe/\S*csv")


@app.post("/command")
async def command(request: Request) -> Response:
    # Run one agent turn for the session identified by `request.key`.
    #
    # Each entry of `request.files` is turned into prompt text by file_handler;
    # those texts come first, each on its own line, followed by the user's query.
    # The executor's output is scanned for image/dataframe paths, each of which
    # is passed to uploader.upload, and the results are returned alongside the
    # text. If the executor raises, the exception message is returned as the
    # response with an empty file list.
    # (Comments, not docstring: FastAPI publishes endpoint docstrings in OpenAPI.)
    query = request.query
    session = request.key

    executor = agent_manager.get_or_create_executor(session)

    # Newline-separate every part so the query is never glued onto the end of
    # the last file's text. With no files this is just the query.
    prompt_parts = [file_handler.handle(file) for file in request.files]
    prompt_parts.append(query)
    prompted_query = "\n".join(prompt_parts)

    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        # NOTE(review): the raw exception text is sent back to the client;
        # consider logging it server-side and returning a generic message.
        return {"response": str(e), "files": []}

    output = res["output"]
    output_files = _OUTPUT_FILE_PATTERN.findall(output)

    return {
        "response": output,
        "files": [uploader.upload(file) for file in output_files],
    }
def serve():
    """Start uvicorn serving ``api.main:app`` on all interfaces.

    The port is read from the ``PORT`` entry of the env settings.
    """
    port = settings["PORT"]
    uvicorn.run("api.main:app", port=port, host="0.0.0.0")