"""HTTP entry point that exposes the tool-using agent through FastAPI."""

import os
import re
from typing import Dict, List, TypedDict

import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from ansi import ANSI, Color, Style, dim_multiline
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.prompts.error import ERROR_PROMPT
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
from logger import logger

app = FastAPI()

# Serve files from StaticUploader.STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static")
uploader = StaticUploader.from_settings(settings)

# NOTE(review): process-wide side effect at import time; presumably so the
# terminal/editor tools operate relative to PLAYGROUND_DIR — confirm.
os.chdir(settings["PLAYGROUND_DIR"])

toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}

# GPU-backed tools and the image handler are registered only when USE_GPU is
# enabled AND torch reports CUDA as available; otherwise they are skipped
# without any message.
if settings["USE_GPU"]:
    import torch

    from core.handlers.image import ImageCaptioning
    from core.tools.gpu import (
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")

agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers)

# Artifact paths mentioned in the agent's answer that should be uploaded.
# Raw string so `\S` is the regex "non-whitespace" class rather than an
# invalid Python string escape (warns on modern Python). Compiled once.
_ARTIFACT_PATH_RE = re.compile(r"image/\S*png|dataframe/\S*csv")


class Request(BaseModel):
    """JSON body accepted by ``POST /command``."""

    key: str  # session key; used to get or create the per-session executor
    query: str  # the user's request text, placed after any file descriptions
    files: List[str]  # file references, each passed to file_handler.handle


class Response(TypedDict):
    """Payload returned by ``POST /command``."""

    response: str  # the agent's output text, or the error text on failure
    files: List[str]  # uploader.upload(...) result for each artifact path found


@app.get("/")
async def index():
    """Greeting endpoint that reports the configured bot name."""
    return {"message": f"Hello World. I'm {settings['BOT_NAME']}."}


@app.post("/command")
async def command(request: Request) -> Response:
    """Run one agent turn for the session identified by ``request.key``.

    The handler output for every entry in ``request.files`` is placed before
    the user's query, one part per line, and the combined text is sent to the
    session's executor as ``{"input": ...}``.

    Args:
        request: Session key, query text and file references.

    Returns:
        ``{"response": <agent output>, "files": [<uploaded artifact>, ...]}``
        on success. If the executor raises, ``{"response": str(error),
        "files": []}``; the exception is logged, not re-raised.
    """
    executor = agent_manager.get_or_create_executor(request.key)

    # Join every part with "\n" so the query is always separated from the
    # last file description (it used to be appended with no separator).
    # With no files this is just the query itself, as before.
    prompt_parts = [file_handler.handle(file) for file in request.files]
    prompt_parts.append(request.query)
    prompted_query = "\n".join(prompt_parts)

    # NOTE(review): executor(...) is called without await inside an async
    # handler, so it occupies the event loop for the entire agent run and
    # other requests wait meanwhile. Consider offloading to a threadpool once
    # the executor/tools are confirmed thread-safe.
    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        # Boundary handler: keep the existing contract (error text goes back
        # to the client) but record the traceback instead of discarding it.
        logger.exception("command: agent executor failed")
        return {"response": str(e), "files": []}

    output = res["output"]
    artifact_paths = _ARTIFACT_PATH_RE.findall(output)
    return {
        "response": output,
        "files": [uploader.upload(path) for path in artifact_paths],
    }


def serve():
    """Run this app under uvicorn on all interfaces at ``settings["EVAL_PORT"]``."""
    uvicorn.run("api.main:app", host="0.0.0.0", port=settings["EVAL_PORT"])