diff --git a/.gitignore b/.gitignore index 5723a43..4b7e183 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,6 @@ audio/ video/ dataframe/ -static/* +static/generated playground/ \ No newline at end of file diff --git a/README.md b/README.md index 98f2576..7ca2b14 100644 --- a/README.md +++ b/README.md @@ -111,20 +111,20 @@ Some tools requires environment variables. Set envs depend on which tools you wa ### 3. Send request to EVAL -- `POST /command` +- `POST /api/execute` - - `key` - session id + - `session` - session id - `files` - urls of file inputs - - `query` - prompt + - `prompt` - prompt - You can send request to EVAL with `curl` or `httpie`. ```bash - curl -X POST -H "Content-Type: application/json" -d '{"key": "sessionid", "files": ["https://example.com/image.png"], "query": "Hi there!"}' http://localhost:8000/command + curl -X POST -H "Content-Type: application/json" -d '{"session": "sessionid", "files": ["https://example.com/image.png"], "prompt": "Hi there!"}' http://localhost:8000/command ``` ```bash - http POST http://localhost:8000/command key=sessionid files:='["https://example.com/image.png"]' query="Hi there!" + http POST http://localhost:8000/command session=sessionid files:='["https://example.com/image.png"]' prompt="Hi there!" ``` - We are planning to make a GUI for EVAL so you can use it without terminal. diff --git a/api/main.py b/api/main.py index 717ad67..8cd667f 100644 --- a/api/main.py +++ b/api/main.py @@ -1,10 +1,15 @@ import os import re +from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Dict, List, TypedDict import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request, UploadFile +from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates + from pydantic import BaseModel from core.agents.manager import AgentManager @@ -19,9 +24,16 @@ from env import settings app = FastAPI() -app.mount("/static", StaticFiles(directory=StaticUploader.STATIC_DIR), name="static") -uploader = StaticUploader.from_settings(settings) -os.chdir(settings["PLAYGROUND_DIR"]) + +BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"]) + +uploader = StaticUploader.from_settings( + settings, path=BASE_DIR / "static", endpoint="static" +) +app.mount("/static", StaticFiles(directory=uploader.path), name="static") + +templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") toolsets: List[BaseToolSet] = [ Terminal(), @@ -54,30 +66,47 @@ if settings["USE_GPU"]: handlers[FileType.IMAGE] = ImageCaptioning("cuda") agent_manager = AgentManager.create(toolsets=toolsets) -file_handler = FileHandler(handlers=handlers) +file_handler = FileHandler(handlers=handlers, path=BASE_DIR) -class Request(BaseModel): - key: str - query: str +class ExecuteRequest(BaseModel): + session: str + prompt: str files: List[str] -class Response(TypedDict): - response: str +class ExecuteResponse(TypedDict): + answer: str files: List[str] -@app.get("/") -async def index(): - return {"message": f"Hello World. I'm {settings['BOT_NAME']}."} +@app.get("/", response_class=HTMLResponse) +async def index(request: Request): + return templates.TemplateResponse("index.html", {"request": request}) + + +@app.get("/dashboard", response_class=HTMLResponse) +async def dashboard(request: Request): + return templates.TemplateResponse("dashboard.html", {"request": request}) -@app.post("/command") -async def command(request: Request) -> Response: - query = request.query +@app.post("/upload") +async def create_upload_file(files: List[UploadFile]): + urls = [] + for file in files: + extension = "." + file.filename.split(".")[-1] + with NamedTemporaryFile(suffix=extension) as tmp_file: + tmp_file.write(file.file.read()) + tmp_file.flush() + urls.append(uploader.upload(tmp_file.name)) + return {"urls": urls} + + +@app.post("/api/execute") +async def execute(request: ExecuteRequest) -> ExecuteResponse: + query = request.prompt files = request.files - session = request.key + session = request.session executor = agent_manager.get_or_create_executor(session) @@ -87,15 +116,26 @@ async def command(request: Request) -> Response: try: res = executor({"input": promptedQuery}) except Exception as e: - return {"response": str(e), "files": []} + return {"answer": str(e), "files": []} - files = re.findall("[file/\S*]", res["output"]) + files = re.findall(r"\[file/\S*\]", res["output"]) + files = [file[1:-1] for file in files] return { - "response": res["output"], + "answer": res["output"], "files": [uploader.upload(file) for file in files], } def serve(): uvicorn.run("api.main:app", host="0.0.0.0", port=settings["EVAL_PORT"]) + + +def dev(): + uvicorn.run( + "api.main:app", + host="0.0.0.0", + port=settings["EVAL_PORT"], + reload=True, + reload_dirs=[BASE_DIR / "core", BASE_DIR / "api"], + ) diff --git a/api/templates/base.html b/api/templates/base.html new file mode 100644 index 0000000..6f1d847 --- /dev/null +++ b/api/templates/base.html @@ -0,0 +1,77 @@ + + + + + + EVAL {% block title %}{% endblock %} + + + {% block head %} {% endblock %} + + +
+ + + + + EVAL + + +
+
+
+ +
+
+ +
+
{% block content %}{% endblock %}
+
+
+ + + + diff --git a/api/templates/dashboard.html b/api/templates/dashboard.html new file mode 100644 index 0000000..a58d3aa --- /dev/null +++ b/api/templates/dashboard.html @@ -0,0 +1,3 @@ +{% extends "base.html" %} {% block head %} {% endblock %} {% block content %} +
Work in progress.
+{% endblock %} diff --git a/api/templates/index.html b/api/templates/index.html new file mode 100644 index 0000000..a20d354 --- /dev/null +++ b/api/templates/index.html @@ -0,0 +1,36 @@ +{% extends "base.html" %} {% block head %} + +{% endblock %} {% block content %} +
+
+
+
+
+ + +
+
+ + +
+
+ + +
+ +
+
+
+
+
+
+
+
+
+ +
+{% endblock %} diff --git a/core/handlers/base.py b/core/handlers/base.py index a1020f7..497ea6e 100644 --- a/core/handlers/base.py +++ b/core/handlers/base.py @@ -1,10 +1,13 @@ import os import uuid +import shutil +from pathlib import Path from enum import Enum from typing import Dict - import requests +from env import settings + class FileType(Enum): IMAGE = "image" @@ -51,8 +54,9 @@ class BaseHandler: class FileHandler: - def __init__(self, handlers: Dict[FileType, BaseHandler]): + def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path): self.handlers = handlers + self.path = path def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler": self.handlers[filetype] = handler @@ -72,6 +76,14 @@ class FileHandler: def handle(self, url: str) -> str: try: - return self.handlers[FileType.from_url(url)].handle(self.download(url)) + if url.startswith(settings["SERVER"]): + local_filename = url[len(settings["SERVER"]) + 1 :] + src = self.path / local_filename + dst = self.path / settings["PLAYGROUND_DIR"] / local_filename + os.makedirs(os.path.dirname(dst), exist_ok=True) + shutil.copy(src, dst) + else: + local_filename = self.download(url) + return self.handlers[FileType.from_url(url)].handle(local_filename) except Exception as e: return "Error: " + str(e) diff --git a/core/prompts/input.py b/core/prompts/input.py index 22dc030..8e6e128 100644 --- a/core/prompts/input.py +++ b/core/prompts/input.py @@ -37,7 +37,7 @@ EVAL_SUFFIX = """TOOLS {bot_name} can ask the user to use tools to look up information that may be helpful in answering the users original question. You are very strict to the filename correctness and will never fake a file name if it does not exist. You will remember to provide the file name loyally if it's provided in the last tool observation. -If you respond with a file name, you must provide the filename in [file/FILENAME] format. +If you have to include files in your response, you must move the files into file/ directory and provide the filename in [file/FILENAME] format. The tools the human can use are: diff --git a/core/upload/static.py b/core/upload/static.py index 1f35414..247c9ad 100644 --- a/core/upload/static.py +++ b/core/upload/static.py @@ -1,26 +1,28 @@ import os import shutil - +from pathlib import Path from env import DotEnv from .base import AbstractUploader class StaticUploader(AbstractUploader): - STATIC_DIR = "static" - - def __init__(self, server: str): + def __init__(self, server: str, path: Path, endpoint: str): self.server = server + self.path = path + self.endpoint = endpoint @staticmethod - def from_settings(settings: DotEnv) -> "StaticUploader": - return StaticUploader(settings["SERVER"]) + def from_settings(settings: DotEnv, path: Path, endpoint: str) -> "StaticUploader": + return StaticUploader(settings["SERVER"], path, endpoint) def get_url(self, uploaded_path: str) -> str: return f"{self.server}/{uploaded_path}" def upload(self, filepath: str): - upload_path = os.path.join(StaticUploader.STATIC_DIR, filepath) - os.makedirs(os.path.dirname(upload_path), exist_ok=True) - shutil.copy(filepath, upload_path) - return f"{self.server}/{upload_path}" + relative_path = Path("generated") / filepath.split("/")[-1] + file_path = self.path / relative_path + os.makedirs(os.path.dirname(file_path), exist_ok=True) + shutil.copy(filepath, file_path) + endpoint_path = self.endpoint / relative_path + return f"{self.server}/{endpoint_path}" diff --git a/poetry.lock b/poetry.lock index cdce425..492449a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -818,7 +818,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -905,7 +905,7 @@ tiktoken = "*" name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1719,6 +1719,21 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.6" +description = "A streaming multipart parser for Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"}, + {file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"}, +] + +[package.extras] +dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"] + [[package]] name = "python-ptrace" version = "0.9.8" @@ -2618,4 +2633,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "75263beb9c26f5ce6dc0e25851441272e3095880ebf0a01e2b9bf9be5a8ab10e" +content-hash = "e6e707098fa68cc228ba4bd454533f460f354991091144c0e3bde29ae0b409e1" diff --git a/pyproject.toml b/pyproject.toml index ff530d4..09344b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ packages = [{include = "api"}] [tool.poetry.scripts] serve = "api.main:serve" +dev = "api.main:dev" [tool.poetry.dependencies] python = "^3.10" @@ -21,6 +22,8 @@ pillow = "^9.4.0" boto3 = "^1.26.94" uvicorn = "^0.21.1" python-ptrace = "^0.9.8" +jinja2 = "^3.1.2" +python-multipart = "^0.0.6" [tool.poetry.group.gpu] optional = true diff --git a/static/execute.js b/static/execute.js new file mode 100644 index 0000000..74ced47 --- /dev/null +++ b/static/execute.js @@ -0,0 +1,61 @@ +const setAnswer = (answer, files) => { + document.getElementById("answer").textContent = answer; + const filesDiv = document.getElementById("response-files"); + filesDiv.innerHTML = ""; + files.forEach((file) => { + const a = document.createElement("a"); + a.classList.add("icon-link"); + a.href = file; + a.textContent = file.split("/").pop(); + a.setAttribute("download", ""); + filesDiv.appendChild(a); + }); +}; + +const submit = async () => { + setAnswer("Loading...", []); + const files = []; + const rawfiles = document.getElementById("files").files; + + if (rawfiles.length > 0) { + const formData = new FormData(); + for (let i = 0; i < rawfiles.length; i++) { + formData.append("files", rawfiles[i]); + } + const respone = await fetch("/upload", { + method: "POST", + body: formData, + }); + const { urls } = await respone.json(); + files.push(...urls); + } + + const prompt = document.getElementById("prompt").value; + const session = document.getElementById("session").value; + + try { + const response = await fetch("/api/execute", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + prompt, + session, + files, + }), + }); + if (response.status !== 200) { + throw new Error(await response.text()); + } + const { answer, files: responseFiles } = await response.json(); + setAnswer(answer, responseFiles); + } catch (e) { + setAnswer("Error: " + e.message, []); + } +}; + +const setRandomSessionId = () => { + const sessionId = Math.random().toString(36).substring(2, 15); + document.getElementById("session").value = sessionId; +}; diff --git a/static/layout.js b/static/layout.js new file mode 100644 index 0000000..516d385 --- /dev/null +++ b/static/layout.js @@ -0,0 +1,11 @@ +const highlightActiveNavItem = () => { + const navItems = document.querySelectorAll("#nav-sidebar > li > a"); + const currentPath = window.location.pathname; + navItems.forEach((item) => { + if (item.getAttribute("href") === currentPath) { + item.classList.add("active"); + } + }); +}; + +highlightActiveNavItem(); diff --git a/static/.gitkeep b/static/styles.css similarity index 100% rename from static/.gitkeep rename to static/styles.css