diff --git a/README.md b/README.md index 1765597..78ebe22 100644 --- a/README.md +++ b/README.md @@ -111,23 +111,32 @@ Some tools requires environment variables. Set envs depend on which tools you wa ### 3. Send request to EVAL -- `POST /api/execute` +- Use the Web GUI to use EVAL with ease - - `session` - session id - - `files` - urls of file inputs - - `prompt` - prompt + - Go to `http://localhost:8000` in your browser + -- You can send request to EVAL with `curl` or `httpie`. +- Or you can manually send requests to EVAL through the API. - ```bash - curl -X POST -H "Content-Type: application/json" -d '{"session": "sessionid", "files": [], "prompt": "Hi there!"}' http://localhost:8000/api/execute - ``` + - `POST /api/execute` - ```bash - http POST http://localhost:8000/api/execute session=sessionid files:='[]' prompt="Hi there!" - ``` + - `session` - session id + - `files` - urls of file inputs + - `prompt` - prompt -- We are planning to make a GUI for EVAL so you can use it without terminal. + - examples + + ```bash + curl -X POST -H "Content-Type: application/json" -d '{"session": "sessionid", "files": [], "prompt": "Hi there!"}' http://localhost:8000/api/execute + ``` + + ```bash + http POST http://localhost:8000/api/execute session=sessionid files:='[]' prompt="Hi there!" + ``` + +- It also supports asynchronous execution. You can use `POST /api/execute/async` instead of `POST /api/execute`, with the same body. + + - It returns the `id` of the execution. Use `GET /api/execute/async/{id}` to get the result. 
## TODO diff --git a/api/container.py b/api/container.py new file mode 100644 index 0000000..028c608 --- /dev/null +++ b/api/container.py @@ -0,0 +1,62 @@ +import os +import re +from pathlib import Path +from typing import Dict, List + +from fastapi.templating import Jinja2Templates + +from core.agents.manager import AgentManager +from core.handlers.base import BaseHandler, FileHandler, FileType +from core.handlers.dataframe import CsvToDataframe +from core.tools.base import BaseToolSet +from core.tools.cpu import ExitConversation, RequestsGet +from core.tools.editor import CodeEditor +from core.tools.terminal import Terminal +from core.upload import StaticUploader +from env import settings + +BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"]) + + +toolsets: List[BaseToolSet] = [ + Terminal(), + CodeEditor(), + RequestsGet(), + ExitConversation(), +] +handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()} + +if settings["USE_GPU"]: + import torch + + from core.handlers.image import ImageCaptioning + from core.tools.gpu import ( + ImageEditing, + InstructPix2Pix, + Text2Image, + VisualQuestionAnswering, + ) + + if torch.cuda.is_available(): + toolsets.extend( + [ + Text2Image("cuda"), + ImageEditing("cuda"), + InstructPix2Pix("cuda"), + VisualQuestionAnswering("cuda"), + ] + ) + handlers[FileType.IMAGE] = ImageCaptioning("cuda") + +agent_manager = AgentManager.create(toolsets=toolsets) + +file_handler = FileHandler(handlers=handlers, path=BASE_DIR) + +templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") + +uploader = StaticUploader.from_settings( + settings, path=BASE_DIR / "static", endpoint="static" +) + +reload_dirs = [BASE_DIR / "core", BASE_DIR / "api"] diff --git a/api/main.py b/api/main.py index 5126f80..898207b 100644 --- a/api/main.py +++ b/api/main.py @@ -1,72 +1,22 @@ -import os import re -from pathlib import Path +from multiprocessing 
import Process from tempfile import NamedTemporaryFile -from typing import Dict, List, TypedDict +from typing import List, TypedDict import uvicorn from fastapi import FastAPI, Request, UploadFile from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates from pydantic import BaseModel -from core.agents.manager import AgentManager -from core.handlers.base import BaseHandler, FileHandler, FileType -from core.handlers.dataframe import CsvToDataframe -from core.tools.base import BaseToolSet -from core.tools.cpu import ExitConversation, RequestsGet -from core.tools.editor import CodeEditor -from core.tools.terminal import Terminal -from core.upload import StaticUploader +from api.container import agent_manager, file_handler, reload_dirs, templates, uploader +from api.worker import get_task_result, start_worker, task_execute from env import settings app = FastAPI() - -BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"]) - -uploader = StaticUploader.from_settings( - settings, path=BASE_DIR / "static", endpoint="static" -) app.mount("/static", StaticFiles(directory=uploader.path), name="static") -templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") - -toolsets: List[BaseToolSet] = [ - Terminal(), - CodeEditor(), - RequestsGet(), - ExitConversation(), -] -handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()} - -if settings["USE_GPU"]: - import torch - - from core.handlers.image import ImageCaptioning - from core.tools.gpu import ( - ImageEditing, - InstructPix2Pix, - Text2Image, - VisualQuestionAnswering, - ) - - if torch.cuda.is_available(): - toolsets.extend( - [ - Text2Image("cuda"), - ImageEditing("cuda"), - InstructPix2Pix("cuda"), - VisualQuestionAnswering("cuda"), - ] - ) - handlers[FileType.IMAGE] = ImageCaptioning("cuda") - -agent_manager = 
AgentManager.create(toolsets=toolsets) -file_handler = FileHandler(handlers=handlers, path=BASE_DIR) - class ExecuteRequest(BaseModel): session: str @@ -107,7 +57,7 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse: files = request.files session = request.session - executor = agent_manager.get_or_create_executor(session) + executor = agent_manager.create_executor(session) promptedQuery = "\n".join([file_handler.handle(file) for file in files]) promptedQuery += query @@ -117,7 +67,7 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse: except Exception as e: return {"answer": str(e), "files": []} - files = re.findall(r"\[file/\S*\]", res["output"]) + files = re.findall(r"\[file://\S*\]", res["output"]) files = [file[1:-1] for file in files] return { @@ -126,15 +76,53 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse: } +@app.post("/api/execute/async") +async def execute_async(request: ExecuteRequest): + query = request.prompt + files = request.files + session = request.session + + promptedQuery = "\n".join([file_handler.handle(file) for file in files]) + promptedQuery += query + + execution = task_execute.delay(session, promptedQuery) + return {"id": execution.id} + + +@app.get("/api/execute/async/{execution_id}") +async def execute_async(execution_id: str): + execution = get_task_result(execution_id) + + result = {} + if execution.status == "SUCCESS" and execution.result: + output = execution.result.get("output", "") + files = re.findall(r"\[file://\S*\]", output) + files = [file[1:-1] for file in files] + result = { + "answer": output, + "files": [uploader.upload(file) for file in files], + } + + return { + "status": execution.status, + "info": execution.info, + "result": result, + } + + def serve(): + p = Process(target=start_worker, args=[]) + p.start() uvicorn.run("api.main:app", host="0.0.0.0", port=settings["EVAL_PORT"]) def dev(): + p = Process(target=start_worker, args=[]) + p.start() uvicorn.run( "api.main:app", 
host="0.0.0.0", port=settings["EVAL_PORT"], reload=True, - reload_dirs=[BASE_DIR / "core", BASE_DIR / "api"], + reload_dirs=reload_dirs, ) diff --git a/api/templates/base.html b/api/templates/base.html index d35ced7..b8f7fd6 100644 --- a/api/templates/base.html +++ b/api/templates/base.html @@ -44,9 +44,9 @@
-

-
+
-->
{% block content %}{% endblock %}
diff --git a/api/templates/index.html b/api/templates/index.html index a20d354..5c6b187 100644 --- a/api/templates/index.html +++ b/api/templates/index.html @@ -17,15 +17,32 @@
- + -
-
-
-
+
+
+
+
+
+
+
+
diff --git a/api/worker.py b/api/worker.py new file mode 100644 index 0000000..44062d3 --- /dev/null +++ b/api/worker.py @@ -0,0 +1,42 @@ +from celery import Celery +from celery.result import AsyncResult + +from api.container import agent_manager +from env import settings + +celery_app = Celery(__name__) +celery_app.conf.broker_url = settings["CELERY_BROKER_URL"] +celery_app.conf.result_backend = settings["CELERY_BROKER_URL"] +celery_app.conf.update( + task_track_started=True, + task_serializer="json", + accept_content=["json"], # Ignore other content + result_serializer="json", + enable_utc=True, +) + + +@celery_app.task(name="task_execute", bind=True) +def task_execute(self, session: str, prompt: str): + executor = agent_manager.create_executor(session, self) + response = executor({"input": prompt}) + result = {"output": response["output"]} + + previous = AsyncResult(self.request.id) + if previous and previous.info: + result.update(previous.info) + + return result + + +def get_task_result(task_id): + return AsyncResult(task_id) + + +def start_worker(): + celery_app.worker_main( + [ + "worker", + "--loglevel=INFO", + ] + ) diff --git a/assets/gui.png b/assets/gui.png new file mode 100644 index 0000000..129157f Binary files /dev/null and b/assets/gui.png differ diff --git a/core/agents/builder.py b/core/agents/builder.py index 1f45aee..10cb5e9 100644 --- a/core/agents/builder.py +++ b/core/agents/builder.py @@ -4,6 +4,7 @@ from core.tools.factory import ToolsFactory from env import settings from langchain.chat_models.base import BaseChatModel from langchain.output_parsers.base import BaseOutputParser +from langchain.callbacks.base import BaseCallbackManager from .chat_agent import ConversationalChatAgent from .llm import ChatOpenAI @@ -17,8 +18,10 @@ class AgentBuilder: self.global_tools: list = None self.toolsets = toolsets - def build_llm(self): - self.llm = ChatOpenAI(temperature=0) + def build_llm(self, callback_manager: BaseCallbackManager = None): + self.llm 
= ChatOpenAI( + temperature=0, callback_manager=callback_manager, verbose=True + ) self.llm.check_access() def build_parser(self): @@ -40,6 +43,12 @@ class AgentBuilder: *ToolsFactory.create_global_tools(self.toolsets), ] + def get_parser(self): + if self.parser is None: + raise ValueError("Parser is not initialized yet") + + return self.parser + def get_global_tools(self): if self.global_tools is None: raise ValueError("Global tools are not initialized yet") diff --git a/core/agents/callback.py b/core/agents/callback.py index 6769a14..a52e9d1 100644 --- a/core/agents/callback.py +++ b/core/agents/callback.py @@ -2,22 +2,45 @@ from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult +from celery import Task from ansi import ANSI, Color, Style, dim_multiline from logger import logger class EVALCallbackHandler(BaseCallbackHandler): + @property + def ignore_llm(self) -> bool: + return False + + def set_parser(self, parser) -> None: + self.parser = parser + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: pass def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - pass + text = response.generations[0][0].text + + parsed = self.parser.parse_all(text) + + logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + parsed["plan"]) + logger.info(ANSI("What I Did").to(Color.blue()) + ": " + parsed["what_i_did"]) + logger.info( + ANSI("Action").to(Color.cyan()) + + ": " + + ANSI(parsed["action"]).to(Style.bold()) + ) + logger.info( + ANSI("Input").to(Color.cyan()) + + ": " + + dim_multiline(parsed["action_input"]) + ) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - pass + logger.info(ANSI(f"on_llm_new_token {token}").to(Color.green(), Style.italic())) def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any @@ -85,3 +108,90 @@ class 
EVALCallbackHandler(BaseCallbackHandler): + ": " + dim_multiline(finish.return_values.get("output", "")) ) + + +class ExecutionTracingCallbackHandler(BaseCallbackHandler): + def __init__(self, execution: Task): + self.execution = execution + self.index = 0 + + def set_parser(self, parser) -> None: + self.parser = parser + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + text = response.generations[0][0].text + parsed = self.parser.parse_all(text) + self.index += 1 + parsed["index"] = self.index + self.execution.update_state(state="LLM_END", meta=parsed) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + pass + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + pass + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + self.execution.update_state(state="CHAIN_ERROR", meta={"error": str(error)}) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + pass + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + pass + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + previous = self.execution.AsyncResult(self.execution.request.id) + self.execution.update_state( + state="TOOL_END", meta={**previous.info, "observation": output} + ) + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + previous = self.execution.AsyncResult(self.execution.request.id) + self.execution.update_state( + state="TOOL_ERROR", 
meta={**previous.info, "error": str(error)} + ) + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], + ) -> None: + pass + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + pass diff --git a/core/agents/manager.py b/core/agents/manager.py index 4f7ebf4..e11d5f7 100644 --- a/core/agents/manager.py +++ b/core/agents/manager.py @@ -1,9 +1,9 @@ -from typing import Dict +from typing import Dict, Optional +from celery import Task -from langchain.agents.agent import Agent, AgentExecutor -from langchain.agents.tools import BaseTool -from langchain.callbacks import set_handler +from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import CallbackManager +from langchain.callbacks import set_handler from langchain.chains.conversation.memory import ConversationBufferMemory from langchain.memory.chat_memory import BaseChatMemory @@ -11,68 +11,73 @@ from core.tools.base import BaseToolSet from core.tools.factory import ToolsFactory from .builder import AgentBuilder -from .callback import EVALCallbackHandler +from .callback import EVALCallbackHandler, ExecutionTracingCallbackHandler + -callback_manager = CallbackManager([EVALCallbackHandler()]) set_handler(EVALCallbackHandler()) class AgentManager: def __init__( self, - agent: Agent, - global_tools: list[BaseTool], toolsets: list[BaseToolSet] = [], ): - self.agent: Agent = agent - self.global_tools: list[BaseTool] = global_tools self.toolsets: list[BaseToolSet] = toolsets + self.memories: Dict[str, BaseChatMemory] = {} self.executors: Dict[str, AgentExecutor] = {} def create_memory(self) -> BaseChatMemory: return ConversationBufferMemory(memory_key="chat_history", return_messages=True) - def create_executor(self, session: str) -> AgentExecutor: - memory: BaseChatMemory = self.create_memory() + def get_or_create_memory(self, session: str) -> BaseChatMemory: + if not (session in 
self.memories): + self.memories[session] = self.create_memory() + return self.memories[session] + + def create_executor( + self, session: str, execution: Optional[Task] = None + ) -> AgentExecutor: + builder = AgentBuilder(self.toolsets) + builder.build_parser() + + callbacks = [] + eval_callback = EVALCallbackHandler() + eval_callback.set_parser(builder.get_parser()) + callbacks.append(eval_callback) + if execution: + execution_callback = ExecutionTracingCallbackHandler(execution) + execution_callback.set_parser(builder.get_parser()) + callbacks.append(execution_callback) + + callback_manager = CallbackManager(callbacks) + + builder.build_llm(callback_manager) + builder.build_global_tools() + + memory: BaseChatMemory = self.get_or_create_memory(session) tools = [ - *self.global_tools, + *builder.get_global_tools(), *ToolsFactory.create_per_session_tools( self.toolsets, get_session=lambda: (session, self.executors[session]), ), ] + for tool in tools: - tool.set_callback_manager(callback_manager) + tool.callback_manager = callback_manager - return AgentExecutor.from_agent_and_tools( - agent=self.agent, + executor = AgentExecutor.from_agent_and_tools( + agent=builder.get_agent(), tools=tools, memory=memory, callback_manager=callback_manager, verbose=True, ) - - def remove_executor(self, session: str) -> None: - if session in self.executors: - del self.executors[session] - - def get_or_create_executor(self, session: str) -> AgentExecutor: - if not (session in self.executors): - self.executors[session] = self.create_executor(session=session) - return self.executors[session] + self.executors[session] = executor + return executor @staticmethod def create(toolsets: list[BaseToolSet]) -> "AgentManager": - builder = AgentBuilder(toolsets) - builder.build_llm() - builder.build_parser() - builder.build_global_tools() - - agent = builder.get_agent() - global_tools = builder.get_global_tools() - return AgentManager( - agent=agent, - global_tools=global_tools, toolsets=toolsets, 
) diff --git a/core/agents/parser.py b/core/agents/parser.py index b0f2714..1c9dc51 100644 --- a/core/agents/parser.py +++ b/core/agents/parser.py @@ -1,16 +1,31 @@ import re -import time from typing import Dict from langchain.output_parsers.base import BaseOutputParser -from ansi import ANSI, Color, Style -from core.agents.callback import dim_multiline from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS -from logger import logger class EvalOutputParser(BaseOutputParser): + @staticmethod + def parse_all(text: str) -> Dict[str, str]: + regex = r"Action: (.*?)[\n]Plan:(.*)[\n]What I Did:(.*)[\n]Action Input: (.*)" + match = re.search(regex, text, re.DOTALL) + if not match: + raise Exception("parse error") + + action = match.group(1).strip() + plan = match.group(2) + what_i_did = match.group(3) + action_input = match.group(4).strip(" ").strip('"') + + return { + "action": action, + "plan": plan, + "what_i_did": what_i_did, + "action_input": action_input, + } + def get_format_instructions(self) -> str: return EVAL_FORMAT_INSTRUCTIONS @@ -20,21 +35,9 @@ class EvalOutputParser(BaseOutputParser): if not match: raise Exception("parse error") - action = match.group(1).strip() - plan = match.group(2) - what_i_did = match.group(3) - action_input = match.group(4) - - logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + plan) - logger.info(ANSI("What I Did").to(Color.blue()) + ": " + what_i_did) - time.sleep(1) - logger.info( - ANSI("Action").to(Color.cyan()) + ": " + ANSI(action).to(Style.bold()) - ) - time.sleep(1) - logger.info(ANSI("Input").to(Color.cyan()) + ": " + dim_multiline(action_input)) - time.sleep(1) - return {"action": action, "action_input": action_input.strip(" ").strip('"')} + parsed = EvalOutputParser.parse_all(text) + + return {"action": parsed["action"], "action_input": parsed["action_input"]} def __str__(self): return "EvalOutputParser" diff --git a/core/prompts/input.py b/core/prompts/input.py index cfd5f9d..47971c9 100644 --- 
a/core/prompts/input.py +++ b/core/prompts/input.py @@ -3,7 +3,7 @@ EVAL_PREFIX = """{bot_name} can execute any user's request. {bot_name} has permission to handle one instance and can handle the environment in it at will. You can code, run, debug, and test yourself. You can correct the code appropriately by looking at the error message. -I can understand, process, and create various types of files. Every image, dataframe, audio, video must be restored in file/ directory. +I can understand, process, and create various types of files. {bot_name} can do whatever it takes to execute the user's request. Let's think step by step. """ @@ -37,7 +37,7 @@ EVAL_SUFFIX = """TOOLS {bot_name} can ask the user to use tools to look up information that may be helpful in answering the users original question. You are very strict to the filename correctness and will never fake a file name if it does not exist. You will remember to provide the file name loyally if it's provided in the last tool observation. -If you have to include files in your response, you must move the files into file/ directory and provide the filename in [file/FILENAME] format. It must be wrapped in square brackets. +If you have to include files in your response, you must provide the filepath in [file://filepath] format. It must be wrapped in square brackets. 
The tools the human can use are: diff --git a/docker-compose.yml b/docker-compose.yml index e5a9bfc..22fd174 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,10 @@ services: - "8000:8000" # eval port env_file: - .env + environment: + - CELERY_BROKER_URL=redis://redis:6379 + depends_on: + - redis eval.gpu: container_name: eval.gpu @@ -40,3 +44,8 @@ services: - driver: nvidia device_ids: ["1"] # You can choose which GPU to use capabilities: [gpu] + depends_on: + - redis + + redis: + image: redis:alpine diff --git a/env.py b/env.py index 061c286..2f9561f 100644 --- a/env.py +++ b/env.py @@ -12,6 +12,7 @@ class DotEnv(TypedDict): EVAL_PORT: int SERVER: str + CELERY_BROKER_URL: str USE_GPU: bool # optional PLAYGROUND_DIR: str # optional LOG_LEVEL: str # optional @@ -31,6 +32,7 @@ EVAL_PORT = int(os.getenv("EVAL_PORT", 8000)) settings: DotEnv = { "EVAL_PORT": EVAL_PORT, "MODEL_NAME": os.getenv("MODEL_NAME", "gpt-4"), + "CELERY_BROKER_URL": os.getenv("CELERY_BROKER_URL", "redis://localhost:6379"), "SERVER": os.getenv("SERVER", f"http://localhost:{EVAL_PORT}"), "USE_GPU": os.getenv("USE_GPU", "False").lower() == "true", "PLAYGROUND_DIR": os.getenv("PLAYGROUND_DIR", "playground"), diff --git a/poetry.lock b/poetry.lock index 492449a..2c84efb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -153,6 +153,21 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "amqp" +version = "5.1.1" +description = "Low-level AMQP client for Python (fork of amqplib)." 
+category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"}, + {file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"}, +] + +[package.dependencies] +vine = ">=5.0.0" + [[package]] name = "anyio" version = "3.6.2" @@ -224,6 +239,18 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "billiard" +version = "3.6.4.0" +description = "Python multiprocessing fork with improvements and bugfixes" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "billiard-3.6.4.0-py3-none-any.whl", hash = "sha256:87103ea78fa6ab4d5c751c4909bcff74617d985de7fa8b672cf8618afd5a875b"}, + {file = "billiard-3.6.4.0.tar.gz", hash = "sha256:299de5a8da28a783d51b197d496bef4f1595dd023a93a4f59dde1886ae905547"}, +] + [[package]] name = "bitsandbytes" version = "0.37.2" @@ -325,6 +352,61 @@ urllib3 = ">=1.25.4,<1.27" [package.extras] crt = ["awscrt (==0.16.9)"] +[[package]] +name = "celery" +version = "5.2.7" +description = "Distributed Task Queue." 
+category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "celery-5.2.7-py3-none-any.whl", hash = "sha256:138420c020cd58d6707e6257b6beda91fd39af7afde5d36c6334d175302c0e14"}, + {file = "celery-5.2.7.tar.gz", hash = "sha256:fafbd82934d30f8a004f81e8f7a062e31413a23d444be8ee3326553915958c6d"}, +] + +[package.dependencies] +billiard = ">=3.6.4.0,<4.0" +click = ">=8.0.3,<9.0" +click-didyoumean = ">=0.0.3" +click-plugins = ">=1.1.1" +click-repl = ">=0.2.0" +kombu = ">=5.2.3,<6.0" +pytz = ">=2021.3" +vine = ">=5.0.0,<6.0" + +[package.extras] +arangodb = ["pyArango (>=1.3.2)"] +auth = ["cryptography"] +azureblockblob = ["azure-storage-blob (==12.9.0)"] +brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"] +cassandra = ["cassandra-driver (<3.21.0)"] +consul = ["python-consul2"] +cosmosdbsql = ["pydocumentdb (==2.3.2)"] +couchbase = ["couchbase (>=3.0.0)"] +couchdb = ["pycouchdb"] +django = ["Django (>=1.11)"] +dynamodb = ["boto3 (>=1.9.178)"] +elasticsearch = ["elasticsearch"] +eventlet = ["eventlet (>=0.32.0)"] +gevent = ["gevent (>=1.5.0)"] +librabbitmq = ["librabbitmq (>=1.5.0)"] +memcache = ["pylibmc"] +mongodb = ["pymongo[srv] (>=3.11.1)"] +msgpack = ["msgpack"] +pymemcache = ["python-memcached"] +pyro = ["pyro4"] +pytest = ["pytest-celery"] +redis = ["redis (>=3.4.1,!=4.0.0,!=4.0.1)"] +s3 = ["boto3 (>=1.9.125)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +solar = ["ephem"] +sqlalchemy = ["sqlalchemy"] +sqs = ["kombu[sqs]"] +tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=1.3.1)"] +zstd = ["zstandard"] + [[package]] name = "certifi" version = "2022.12.7" @@ -437,6 +519,56 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "click-didyoumean" +version = "0.3.0" +description = "Enables git-like *did-you-mean* feature in click" +category = "main" +optional = false +python-versions = ">=3.6.2,<4.0.0" +files = [ + {file = 
"click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"}, + {file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"}, +] + +[package.dependencies] +click = ">=7" + +[[package]] +name = "click-plugins" +version = "1.1.1" +description = "An extension module for click to enable registering CLI commands via setuptools entry-points." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"}, + {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"}, +] + +[package.dependencies] +click = ">=4.0" + +[package.extras] +dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"] + +[[package]] +name = "click-repl" +version = "0.2.0" +description = "REPL plugin for Click" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "click-repl-0.2.0.tar.gz", hash = "sha256:cd12f68d745bf6151210790540b4cb064c7b13e571bc64b6957d98d120dacfd8"}, + {file = "click_repl-0.2.0-py3-none-any.whl", hash = "sha256:94b3fbbc9406a236f176e0506524b2937e4b23b6f4c0c0b2a0a83f8a64e9194b"}, +] + +[package.dependencies] +click = "*" +prompt-toolkit = "*" +six = "*" + [[package]] name = "cmake" version = "3.26.1" @@ -844,6 +976,38 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "kombu" +version = "5.2.4" +description = "Messaging library for Python." 
+category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "kombu-5.2.4-py3-none-any.whl", hash = "sha256:8b213b24293d3417bcf0d2f5537b7f756079e3ea232a8386dcc89a59fd2361a4"}, + {file = "kombu-5.2.4.tar.gz", hash = "sha256:37cee3ee725f94ea8bb173eaab7c1760203ea53bbebae226328600f9d2799610"}, +] + +[package.dependencies] +amqp = ">=5.0.9,<6.0.0" +vine = "*" + +[package.extras] +azureservicebus = ["azure-servicebus (>=7.0.0)"] +azurestoragequeues = ["azure-storage-queue"] +consul = ["python-consul (>=0.6.0)"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +mongodb = ["pymongo (>=3.3.0,<3.12.1)"] +msgpack = ["msgpack"] +pyro = ["pyro4"] +qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] +redis = ["redis (>=3.4.1,!=4.0.0,!=4.0.1)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +sqlalchemy = ["sqlalchemy"] +sqs = ["boto3 (>=1.9.12)", "pycurl (>=7.44.1,<7.45.0)", "urllib3 (>=1.26.7)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=1.3.1)"] + [[package]] name = "langchain" version = "0.0.115" @@ -1528,6 +1692,21 @@ files = [ docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] +[[package]] +name = "prompt-toolkit" +version = "3.0.38" +description = "Library for building powerful interactive command lines in Python" +category = "main" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.38-py3-none-any.whl", hash = "sha256:45ea77a2f7c60418850331366c81cf6b5b9cf4c7fd34616f733c5427e6abbb1f"}, + {file = "prompt_toolkit-3.0.38.tar.gz", hash = "sha256:23ac5d50538a9a38c8bde05fecb47d0b403ecd0662857a86f886f798563d5b9b"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "psutil" version = "5.9.4" @@ -1808,6 +1987,25 @@ files = [ {file = "PyYAML-6.0.tar.gz", hash = 
"sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] +[[package]] +name = "redis" +version = "4.5.4" +description = "Python client for Redis database and key-value store" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-4.5.4-py3-none-any.whl", hash = "sha256:2c19e6767c474f2e85167909061d525ed65bea9301c0770bb151e041b7ac89a2"}, + {file = "redis-4.5.4.tar.gz", hash = "sha256:73ec35da4da267d6847e47f68730fdd5f62e2ca69e3ef5885c6a78a9374c3893"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.2", markers = "python_version <= \"3.11.2\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "regex" version = "2023.3.23" @@ -2411,6 +2609,15 @@ category = "dev" optional = false python-versions = "*" files = [ + {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, + {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, + {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, + {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, + {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, + {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, + {file = 
"triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, + {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, + {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, @@ -2496,6 +2703,30 @@ h11 = ">=0.8" [package.extras] standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] +[[package]] +name = "vine" +version = "5.0.0" +description = "Promises, promises, promises." 
+category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"}, + {file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"}, +] + +[[package]] +name = "wcwidth" +version = "0.2.6" +description = "Measures the displayed width of unicode strings in a terminal" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"}, + {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"}, +] + [[package]] name = "wheel" version = "0.40.0" @@ -2633,4 +2864,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e6e707098fa68cc228ba4bd454533f460f354991091144c0e3bde29ae0b409e1" +content-hash = "3a7cbc62858401d620de4eeab75906ea8caa97d780632e58e1d6316681228d13" diff --git a/pyproject.toml b/pyproject.toml index 09344b4..f964cc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ uvicorn = "^0.21.1" python-ptrace = "^0.9.8" jinja2 = "^3.1.2" python-multipart = "^0.0.6" +celery = "^5.2.7" +redis = "^4.5.4" [tool.poetry.group.gpu] optional = true diff --git a/static/execute.js b/static/execute.js index 74ced47..832378b 100644 --- a/static/execute.js +++ b/static/execute.js @@ -1,6 +1,25 @@ -const setAnswer = (answer, files) => { - document.getElementById("answer").textContent = answer; - const filesDiv = document.getElementById("response-files"); +const $ = (selector) => document.querySelector(selector); + +const setLoader = (isLoading) => { + const button = $("#submit"); + const loader = $("#submit-loader"); + if (isLoading) { + button.style.display = "none"; + loader.style.display = 
"block"; + } else { + button.style.display = "block"; + loader.style.display = "none"; + } +}; + +const setAnswer = (answer, files = []) => { + if (answer) { + $("#answer").textContent = answer; + } else { + $("#answer").innerHTML = createSpinner(); + } + + const filesDiv = $("#response-files"); filesDiv.innerHTML = ""; files.forEach((file) => { const a = document.createElement("a"); @@ -12,50 +31,247 @@ const setAnswer = (answer, files) => { }); }; -const submit = async () => { - setAnswer("Loading...", []); - const files = []; - const rawfiles = document.getElementById("files").files; - - if (rawfiles.length > 0) { - const formData = new FormData(); - for (let i = 0; i < rawfiles.length; i++) { - formData.append("files", rawfiles[i]); +class EvalApi { + constructor({ onComplete, onError, onSettle, onLLMEnd, onToolEnd }) { + this.executionId = null; + this.pollInterval = null; + this.onComplete = (answer, files, info) => { + onComplete(answer, files, info); + onSettle(); + }; + this.onError = (error) => { + onError(error); + onSettle(); + }; + this.onLLMEnd = (info) => { + onLLMEnd(info); + }; + this.onToolEnd = (info) => { + onToolEnd(info); + }; + } + async uploadFiles(rawfiles) { + const files = []; + + if (rawfiles.length > 0) { + const formData = new FormData(); + for (let i = 0; i < rawfiles.length; i++) { + formData.append("files", rawfiles[i]); + } + const respone = await fetch("/upload", { + method: "POST", + body: formData, + }); + const { urls } = await respone.json(); + files.push(...urls); } - const respone = await fetch("/upload", { - method: "POST", - body: formData, - }); - const { urls } = await respone.json(); - files.push(...urls); + + return files; } - const prompt = document.getElementById("prompt").value; - const session = document.getElementById("session").value; - - try { - const response = await fetch("/api/execute", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - prompt, - session, - 
files, - }), - }); - if (response.status !== 200) { - throw new Error(await response.text()); + async execute(prompt, session, files) { + try { + const response = await fetch("/api/execute/async", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + prompt, + session, + files, + }), + }); + if (response.status !== 200) { + throw new Error(await response.text()); + } + const { id: executionId } = await response.json(); + this.executionId = executionId; + this.pollInterval = setInterval(this.poll.bind(this), 1000); + } catch (e) { + clearInterval(this.pollInterval); + this.onError(e); } - const { answer, files: responseFiles } = await response.json(); - setAnswer(answer, responseFiles); - } catch (e) { - setAnswer("Error: " + e.message, []); } + + async poll() { + try { + const response = await fetch(`/api/execute/async/${this.executionId}`, { + method: "GET", + }); + if (response.status !== 200) { + throw new Error(await response.text()); + } + const { status, result, info } = await response.json(); + switch (status) { + case "PENDING": + break; + case "FAILURE": + throw new Error("Execution failed"); + case "LLM_END": + this.onLLMEnd(info); + break; + case "TOOL_END": + this.onToolEnd(info); + break; + case "SUCCESS": + clearInterval(this.pollInterval); + this.onComplete(result.answer, result.files, info); + break; + } + } catch (e) { + clearInterval(this.pollInterval); + this.onError(e); + } + } +} + +const submit = async () => { + setAnswer(""); + setLoader(true); + + const actions = $("#actions"); + actions.innerHTML = ""; + + let currentActionIndex = 0; + + const onInfo = (info) => { + if (currentActionIndex >= info.index) { + return; + } + currentActionIndex = info.index; + const w = document.createElement("div"); + w.innerHTML = createActionCard( + info.index, + info.action, + info.action_input, + info.what_i_did, + info.plan, + info.observation + ); + actions.appendChild(w); + }; + + const api = new EvalApi({ 
+ onSettle: () => setLoader(false), + onError: (error) => setAnswer(`Error: ${error.message}`, []), + onComplete: (answer, files, info) => { + setAnswer(answer, files); + onInfo(info); + }, + onLLMEnd: onInfo, + onToolEnd: onInfo, + }); + + const prompt = $("#prompt").value; + const session = $("#session").value; + const files = await api.uploadFiles($("#files").files); + + await api.execute(prompt, session, files); }; const setRandomSessionId = () => { const sessionId = Math.random().toString(36).substring(2, 15); - document.getElementById("session").value = sessionId; + $("#session").value = sessionId; }; + +const createSpinner = () => ` +
+
+
+`; + +const createActionCard = ( + index, + action, + input, + whatIdid, + plan, + observation +) => ` +
+
+

+ +

+
+
+ + + ${ + action !== "Final Answer" + ? ` + + + ` + : "" + } + + + + + +
Input
${input}
What I Did
${whatIdid}
+ + + + + + + + + ${plan + .split("- ") + .map((p) => p.trim()) + .filter((p) => p.length > 0) + .map( + (p) => ` + + ${ + p.startsWith("[ ]") + ? ` + ` + : "" + } + ${ + p.startsWith("[x]") + ? ` + ` + : "" + } + ` + ) + .join("")} + +
Plan
${p.replace("[ ]", "")}${p.replace("[x]", "")}
+ + ${ + action !== "Final Answer" + ? ` + + + + + + + + + + +
Observation
+
${observation}
+
` + : "" + } +
+
+
+
`; diff --git a/static/styles.css b/static/styles.css index b609ffa..4afd0f4 100644 --- a/static/styles.css +++ b/static/styles.css @@ -1,7 +1,7 @@ .logo { border-radius: 50%; overflow: hidden; - height: 64px; - width: 64px; + height: 48px; + width: 48px; margin-right: 20px; }