2023-03-18 12:42:14 +00:00
|
|
|
from typing import Dict, List, TypedDict
|
2023-03-17 15:55:15 +00:00
|
|
|
import re
|
2023-03-23 07:33:45 +00:00
|
|
|
import uvicorn
|
2023-03-17 15:55:15 +00:00
|
|
|
|
2023-03-26 06:28:14 +00:00
|
|
|
import torch
|
2023-03-17 15:55:15 +00:00
|
|
|
from fastapi import FastAPI
|
2023-03-23 07:33:45 +00:00
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
2023-03-17 15:55:15 +00:00
|
|
|
from pydantic import BaseModel
|
2023-03-20 08:27:20 +00:00
|
|
|
|
2023-03-18 12:42:14 +00:00
|
|
|
from env import settings
|
2023-03-17 15:55:15 +00:00
|
|
|
|
2023-03-23 07:33:45 +00:00
|
|
|
from core.prompts.error import ERROR_PROMPT
|
|
|
|
from core.agents.manager import AgentManager
|
|
|
|
from core.tools.base import BaseToolSet
|
2023-03-30 06:37:56 +00:00
|
|
|
from core.tools.terminal import Terminal
|
|
|
|
from core.tools.editor import CodeEditor
|
2023-03-23 07:33:45 +00:00
|
|
|
from core.tools.cpu import (
|
2023-03-18 12:42:14 +00:00
|
|
|
RequestsGet,
|
|
|
|
WineDB,
|
2023-03-22 03:21:29 +00:00
|
|
|
ExitConversation,
|
2023-03-18 12:42:14 +00:00
|
|
|
)
|
2023-03-23 07:33:45 +00:00
|
|
|
from core.tools.gpu import (
|
2023-03-18 12:42:14 +00:00
|
|
|
ImageEditing,
|
|
|
|
InstructPix2Pix,
|
|
|
|
Text2Image,
|
|
|
|
VisualQuestionAnswering,
|
|
|
|
)
|
2023-03-23 07:33:45 +00:00
|
|
|
from core.handlers.base import BaseHandler, FileHandler, FileType
|
|
|
|
from core.handlers.image import ImageCaptioning
|
|
|
|
from core.handlers.dataframe import CsvToDataframe
|
|
|
|
from core.upload import StaticUploader
|
|
|
|
|
2023-03-22 00:34:52 +00:00
|
|
|
from logger import logger
|
2023-03-17 15:55:15 +00:00
|
|
|
|
|
|
|
# ASGI application object; the route decorators below register on it.
app = FastAPI()

# Serve the uploader's static directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory=StaticUploader.STATIC_DIR),
    name="static",
)

# Uploader built from the project settings; /command uses it to publish
# files referenced in the agent's output.
uploader = StaticUploader.from_settings(settings)
|
2023-03-21 12:20:15 +00:00
|
|
|
|
2023-03-30 06:37:56 +00:00
|
|
|
# GPU-backed components are enabled only when the setting is truthy AND CUDA
# is actually available (short-circuits: CUDA is not probed if the setting
# is falsy).
use_gpu = settings["USE_GPU"] and torch.cuda.is_available()

# Toolsets that are always registered.
toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]

if use_gpu:
    # Each GPU toolset is constructed with the "cuda" device string, in the
    # order listed here.
    toolsets += [
        gpu_toolset("cuda")
        for gpu_toolset in (
            Text2Image,
            ImageEditing,
            InstructPix2Pix,
            VisualQuestionAnswering,
        )
    ]
|
2023-03-26 06:28:14 +00:00
|
|
|
|
|
|
|
# File handlers keyed by file type. The dataframe handler is always present.
handlers: Dict[FileType, BaseHandler] = {
    FileType.DATAFRAME: CsvToDataframe(),
}

if use_gpu:
    # The image handler is only registered when GPU use is enabled; it is
    # constructed with the "cuda" device string.
    handlers.update({FileType.IMAGE: ImageCaptioning("cuda")})
|
2023-03-18 12:42:14 +00:00
|
|
|
|
|
|
|
# Register the WineDB toolset only when both the host and the password
# settings are truthy (the password is not read if the host is falsy).
winedb_configured = settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]
if winedb_configured:
    toolsets.append(WineDB())
|
|
|
|
|
2023-03-21 12:20:15 +00:00
|
|
|
# Manager that hands out per-session executors (see /command), created with
# every toolset registered above.
agent_manager = AgentManager.create(
    toolsets=toolsets,
)

# Dispatcher for the files attached to a request, built from the handlers
# mapping above.
file_handler = FileHandler(
    handlers=handlers,
)
|
2023-03-17 15:55:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Request body accepted by POST /command.
# (Documented with comments rather than a docstring so the generated schema
# is left exactly as it was.)
class Request(BaseModel):
    # Passed to the agent manager to look up or create the executor.
    key: str
    # The user's query text; appended after the prompts produced for `files`.
    query: str
    # File references; each one is passed to the file handler.
    files: List[str]
|
2023-03-17 15:55:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Shape of the JSON body returned by POST /command.
class Response(TypedDict):
    # The agent's output text, or the error text if both attempts failed.
    response: str
    # Results of uploading the files referenced in the output (may be empty).
    files: List[str]
|
2023-03-17 15:55:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def index():
    """Return a greeting message that includes the configured bot name."""
    bot_name = settings["BOT_NAME"]
    return {"message": f"Hello World. I'm {bot_name}."}
|
2023-03-17 15:55:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
@app.post("/command")
async def command(request: Request) -> Response:
    """Run a user query through the session's agent executor.

    The prompts produced by the file handler for each attached file are
    joined with newlines, the query text is appended, and the result is fed
    to the executor obtained for ``request.key``. If the executor raises, it
    is invoked once more with ``ERROR_PROMPT`` describing the failure; if
    that also raises, the error text is returned as the response.

    Args:
        request: Body carrying the session key, the query and the attached
            file references.

    Returns:
        A dict with ``"response"`` (the executor's ``"output"`` text, or the
        error text when both attempts failed) and ``"files"`` (the result of
        ``uploader.upload`` for every image and dataframe path found in the
        output; empty on failure).
    """
    query = request.query
    files = request.files
    session = request.key

    logger.info("=============== Running =============")
    logger.info(f"Query: {query}, Files: {files}")
    executor = agent_manager.get_or_create_executor(session)

    logger.info(f"======> Previous memory:\n\t{executor.memory}")

    # One prompt fragment per attached file, newline-separated, followed
    # directly by the user's query.
    prompted_query = "\n".join([file_handler.handle(file) for file in files])
    prompted_query += query
    logger.info(f"======> Prompted Text:\n\t{prompted_query}")

    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        logger.error(f"error while processing request: {str(e)}")
        # Second chance: tell the agent what went wrong and let it respond.
        try:
            res = executor(
                {
                    "input": ERROR_PROMPT.format(promptedQuery=prompted_query, e=str(e)),
                }
            )
        except Exception as retry_error:
            # Previously this failure was returned without being logged.
            logger.error(f"error while processing error prompt: {str(retry_error)}")
            return {"response": str(retry_error), "files": []}

    output = res["output"]

    # Raw strings: "\S" in a plain string literal is an invalid escape
    # sequence that newer Python versions warn about.
    # NOTE: \S* is greedy, so two paths with no whitespace between them are
    # captured as a single match.
    images = re.findall(r"(image/\S*png)", output)
    dataframes = re.findall(r"(dataframe/\S*csv)", output)

    return {
        "response": output,
        "files": [uploader.upload(image) for image in images]
        + [uploader.upload(dataframe) for dataframe in dataframes],
    }
|
2023-03-23 07:33:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def serve():
    """Run the ``api.main:app`` application under uvicorn.

    Binds to 0.0.0.0 on the port taken from ``settings["PORT"]``.
    """
    listen_port = settings["PORT"]
    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=listen_port,
    )
|