You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
EVAL/api/main.py

104 lines
2.6 KiB
Python

1 year ago
import re
from typing import Dict, List, TypedDict
1 year ago
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from ansi import ANSI, Color, Style, dim_multiline
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.prompts.error import ERROR_PROMPT
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
from logger import logger
1 year ago
# ---------------------------------------------------------------------------
# Module-level wiring: web app, uploader, tool sets, file handlers and the
# agent manager that the endpoints below rely on.
# ---------------------------------------------------------------------------
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory=StaticUploader.STATIC_DIR),
    name="static",
)

uploader = StaticUploader.from_settings(settings)

# Tool sets that are registered regardless of the GPU setting.
toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]

# File handlers keyed by file type; only the CSV handler is registered
# unconditionally.
handlers: Dict[FileType, BaseHandler] = {}
handlers[FileType.DATAFRAME] = CsvToDataframe()

if settings["USE_GPU"]:
    # GPU-only dependencies are imported lazily so that they are only
    # required when the USE_GPU setting is truthy.
    import torch

    from core.handlers.image import ImageCaptioning
    from core.tools.gpu import (
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    if torch.cuda.is_available():
        # Instantiated in the same order as before, each on the "cuda" device.
        toolsets += [
            gpu_tool("cuda")
            for gpu_tool in (
                Text2Image,
                ImageEditing,
                InstructPix2Pix,
                VisualQuestionAnswering,
            )
        ]
        # NOTE(review): the original indentation was lost; the image handler
        # is assumed to be registered only when CUDA is available -- confirm.
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")

agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers)
1 year ago
class Request(BaseModel):
    """Request body accepted by the ``/command`` endpoint."""

    # Session key; passed to ``agent_manager.get_or_create_executor``.
    key: str
    # The user's query text, appended to the prompt sent to the executor.
    query: str
    # Files attached to the request; each one is passed to
    # ``file_handler.handle`` before the query is executed.
    files: List[str]
1 year ago
class Response(TypedDict):
    """Dictionary shape returned by the ``/command`` endpoint."""

    # Text of the reply (the executor output, or the error message when the
    # executor raised).
    response: str
    # One entry per generated file, as returned by ``uploader.upload``.
    files: List[str]
1 year ago
@app.get("/")
async def index():
    """Return a greeting that includes the configured ``BOT_NAME`` setting."""
    greeting = f"Hello World. I'm {settings['BOT_NAME']}."
    return {"message": greeting}
1 year ago
# Paths of generated artefacts mentioned in the agent output. The pattern is a
# raw string (so ``\S`` is not treated as a string escape) and has no capture
# groups, so ``findall`` yields plain strings instead of 2-tuples.
_OUTPUT_FILE_PATTERN = re.compile(r"image/\S*png|dataframe/\S*csv")


def _extract_output_files(output: str) -> List[str]:
    """Return every ``image/...png`` or ``dataframe/...csv`` path in *output*.

    Matches are returned in order of appearance; duplicates are kept.
    """
    return _OUTPUT_FILE_PATTERN.findall(output)


@app.post("/command")
async def command(request: Request) -> Response:
    """Run the user's query through the session's executor.

    Each attached file is passed through ``file_handler.handle`` and the
    results are joined with newlines; the query is then appended directly to
    that text and the whole prompt is given to the executor obtained for
    ``request.key``.

    Args:
        request: Session key, query text and attached files.

    Returns:
        A dict with the executor output under ``"response"`` and, under
        ``"files"``, the result of ``uploader.upload`` for every generated
        file path found in that output. If the executor raises, the
        exception text is returned as the response with an empty file list.
        Errors raised while obtaining the executor or handling the attached
        files are not caught here.
    """
    executor = agent_manager.get_or_create_executor(request.key)

    # File descriptions come first, then the query (no separator is inserted
    # between the last description and the query).
    prompted_query = "\n".join(
        [file_handler.handle(file) for file in request.files]
    )
    prompted_query += request.query

    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        # Best-effort: report the failure to the client instead of a 500.
        return {"response": str(e), "files": []}

    output = res["output"]
    generated_files = _extract_output_files(output)

    return {
        "response": output,
        "files": [uploader.upload(file) for file in generated_files],
    }
def serve():
    """Run the ``api.main:app`` application with uvicorn.

    Listens on all interfaces (``0.0.0.0``) on the port given by the
    ``PORT`` setting.
    """
    server_options = {"host": "0.0.0.0", "port": settings["PORT"]}
    uvicorn.run("api.main:app", **server_options)