Feature/async inference (#31)

* feat: async execution

* feat: poll execution status

* feat: disable submit button while executing

* feat: execution tracing callback

* feat: handle polling events on client side

* fix: show last tool response

* fix: tool index

* fix: remove sidebar

* fix: action input width

* fix: hide input and observation on final answer

* feat: stack actions

* fix: run workers on serve

* docs: update usage
ChungHwan Han 2023-04-08 14:10:16 +09:00 committed by GitHub
parent 880eba351e
commit 7123fe01b3
18 changed files with 886 additions and 181 deletions


@@ -111,13 +111,20 @@ Some tools requires environment variables. Set envs depend on which tools you wa
 ### 3. Send request to EVAL
-- `POST /api/execute`
+- Use the Web GUI to use EVAL in ease
+  - Go to `http://localhost:8000` in your browser
+  <img src="assets/gui.png" />
+- Or you can manually send request to EVAL with APIs.
+  - `POST /api/execute`
   - `session` - session id
   - `files` - urls of file inputs
   - `prompt` - prompt
-- You can send request to EVAL with `curl` or `httpie`.
+  - examples
 ```bash
 curl -X POST -H "Content-Type: application/json" -d '{"session": "sessionid", "files": [], "prompt": "Hi there!"}' http://localhost:8000/api/execute
@@ -127,7 +134,9 @@ Some tools requires environment variables. Set envs depend on which tools you wa
 http POST http://localhost:8000/api/execute session=sessionid files:='[]' prompt="Hi there!"
 ```
-- We are planning to make a GUI for EVAL so you can use it without terminal.
+- It also supports asynchronous execution. You can use `POST /api/execute/async` instead of `POST /api/execute`, with same body.
+  - It returns `id` of the execution. Use `GET /api/execute/async/{id}` to get the result.
 ## TODO
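For reference, a minimal Python client for the async flow described above. This is only a sketch: it assumes the third-party `requests` package and a server on localhost; endpoint paths and response fields follow the handlers added in api/main.py.

```python
import time

import requests  # third-party HTTP client, not part of this repo

BASE = "http://localhost:8000"

# 1. Enqueue the execution; the server returns the Celery task id.
resp = requests.post(
    f"{BASE}/api/execute/async",
    json={"session": "my-session", "files": [], "prompt": "Hi there!"},
)
execution_id = resp.json()["id"]

# 2. Poll until the task reports SUCCESS (mirrors the polling loop in the web GUI).
while True:
    payload = requests.get(f"{BASE}/api/execute/async/{execution_id}").json()
    status = payload["status"]
    if status == "SUCCESS":
        print(payload["result"]["answer"])
        break
    if status == "FAILURE":
        raise RuntimeError(payload)
    # Intermediate states such as LLM_END / TOOL_END expose the parsed
    # action metadata in payload["info"].
    time.sleep(1)
```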

api/container.py (new file, 62 lines)

@@ -0,0 +1,62 @@
import os
import re
from pathlib import Path
from typing import Dict, List

from fastapi.templating import Jinja2Templates

from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings

BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"])

toolsets: List[BaseToolSet] = [
    Terminal(),
    CodeEditor(),
    RequestsGet(),
    ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}

if settings["USE_GPU"]:
    import torch

    from core.handlers.image import ImageCaptioning
    from core.tools.gpu import (
        ImageEditing,
        InstructPix2Pix,
        Text2Image,
        VisualQuestionAnswering,
    )

    if torch.cuda.is_available():
        toolsets.extend(
            [
                Text2Image("cuda"),
                ImageEditing("cuda"),
                InstructPix2Pix("cuda"),
                VisualQuestionAnswering("cuda"),
            ]
        )
        handlers[FileType.IMAGE] = ImageCaptioning("cuda")

agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
uploader = StaticUploader.from_settings(
    settings, path=BASE_DIR / "static", endpoint="static"
)
reload_dirs = [BASE_DIR / "core", BASE_DIR / "api"]
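The point of this new module appears to be that both the FastAPI app and the Celery worker import the same module-level singletons, so toolsets, handlers, and the agent manager are built once per process. A hypothetical import sketch (the session id is arbitrary):

```python
# Both api/main.py and api/worker.py import from this module instead of
# building their own toolsets/handlers.
from api.container import agent_manager, file_handler, uploader

executor = agent_manager.create_executor("demo-session")
res = executor({"input": "echo hello"})
print(res["output"])
```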


@@ -1,72 +1,22 @@
-import os
 import re
-from pathlib import Path
+from multiprocessing import Process
 from tempfile import NamedTemporaryFile
-from typing import Dict, List, TypedDict
+from typing import List, TypedDict
 import uvicorn
 from fastapi import FastAPI, Request, UploadFile
 from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
-from fastapi.templating import Jinja2Templates
 from pydantic import BaseModel
-from core.agents.manager import AgentManager
-from core.handlers.base import BaseHandler, FileHandler, FileType
-from core.handlers.dataframe import CsvToDataframe
-from core.tools.base import BaseToolSet
-from core.tools.cpu import ExitConversation, RequestsGet
-from core.tools.editor import CodeEditor
-from core.tools.terminal import Terminal
-from core.upload import StaticUploader
+from api.container import agent_manager, file_handler, reload_dirs, templates, uploader
+from api.worker import get_task_result, start_worker, task_execute
 from env import settings

 app = FastAPI()
-BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"])
-uploader = StaticUploader.from_settings(
-    settings, path=BASE_DIR / "static", endpoint="static"
-)
 app.mount("/static", StaticFiles(directory=uploader.path), name="static")
-templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
-toolsets: List[BaseToolSet] = [
-    Terminal(),
-    CodeEditor(),
-    RequestsGet(),
-    ExitConversation(),
-]
-handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
-if settings["USE_GPU"]:
-    import torch
-    from core.handlers.image import ImageCaptioning
-    from core.tools.gpu import (
-        ImageEditing,
-        InstructPix2Pix,
-        Text2Image,
-        VisualQuestionAnswering,
-    )
-    if torch.cuda.is_available():
-        toolsets.extend(
-            [
-                Text2Image("cuda"),
-                ImageEditing("cuda"),
-                InstructPix2Pix("cuda"),
-                VisualQuestionAnswering("cuda"),
-            ]
-        )
-        handlers[FileType.IMAGE] = ImageCaptioning("cuda")
-agent_manager = AgentManager.create(toolsets=toolsets)
-file_handler = FileHandler(handlers=handlers, path=BASE_DIR)

 class ExecuteRequest(BaseModel):
     session: str
@@ -107,7 +57,7 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse:
     files = request.files
     session = request.session
-    executor = agent_manager.get_or_create_executor(session)
+    executor = agent_manager.create_executor(session)
     promptedQuery = "\n".join([file_handler.handle(file) for file in files])
     promptedQuery += query
@@ -117,7 +67,7 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse:
     except Exception as e:
         return {"answer": str(e), "files": []}
-    files = re.findall(r"\[file/\S*\]", res["output"])
+    files = re.findall(r"\[file://\S*\]", res["output"])
     files = [file[1:-1] for file in files]
     return {
@@ -126,15 +76,53 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse:
     }

+@app.post("/api/execute/async")
+async def execute_async(request: ExecuteRequest):
+    query = request.prompt
+    files = request.files
+    session = request.session
+    promptedQuery = "\n".join([file_handler.handle(file) for file in files])
+    promptedQuery += query
+    execution = task_execute.delay(session, promptedQuery)
+    return {"id": execution.id}

+@app.get("/api/execute/async/{execution_id}")
+async def execute_async(execution_id: str):
+    execution = get_task_result(execution_id)
+    result = {}
+    if execution.status == "SUCCESS" and execution.result:
+        output = execution.result.get("output", "")
+        files = re.findall(r"\[file://\S*\]", output)
+        files = [file[1:-1] for file in files]
+        result = {
+            "answer": output,
+            "files": [uploader.upload(file) for file in files],
+        }
+    return {
+        "status": execution.status,
+        "info": execution.info,
+        "result": result,
+    }

 def serve():
+    p = Process(target=start_worker, args=[])
+    p.start()
     uvicorn.run("api.main:app", host="0.0.0.0", port=settings["EVAL_PORT"])

 def dev():
+    p = Process(target=start_worker, args=[])
+    p.start()
     uvicorn.run(
         "api.main:app",
         host="0.0.0.0",
         port=settings["EVAL_PORT"],
         reload=True,
-        reload_dirs=[BASE_DIR / "core", BASE_DIR / "api"],
+        reload_dirs=reload_dirs,
     )


@@ -44,9 +44,9 @@
       </ul>
     </header>
     <div class="d-flex flex-row">
-      <div
+      <!-- <div
         class="d-flex flex-column flex-shrink-0 p-3 bg-body-tertiary"
-        style="width: 280px; height: 80vh"
+        style="width: 240px; height: 80vh"
       >
         <ul id="nav-sidebar" class="nav nav-pills flex-column mb-auto">
           <li class="nav-item">
@@ -67,7 +67,7 @@
           </li>
         </ul>
         <hr />
-      </div>
+      </div> -->
       <div class="w-100">
         <div class="container">{% block content %}{% endblock %}</div>


@@ -17,15 +17,32 @@
         <label for="session" class="form-label">Session</label>
         <input id="session" name="session" class="form-control" />
       </div>
-      <button type="submit" class="btn btn-primary" onclick="submit(event)">
+      <button
+        id="submit"
+        type="submit"
+        class="btn btn-primary"
+        onclick="submit(event)"
+      >
         Submit
       </button>
+      <button
+        id="submit-loader"
+        class="btn btn-primary disabled"
+        style="display: none"
+      >
+        Submit
+        <div class="spinner-border spinner-border-sm"></div>
+      </button>
+    </div>
+  </div>
+  <div class="bg-body-tertiary border rounded-3 p-2">
+    <div id="actions"></div>
+    <div class="card m-2">
+      <div class="card-body">
+        <div id="answer" class="card-text"></div>
+        <div id="response-files" class="card-text"></div>
       </div>
     </div>
-  <div class="bg-body-tertiary border rounded-3 p-3">
-    <div>
-      <div id="answer"></div>
-      <div id="response-files"></div>
-    </div>
   </div>
 </div>

api/worker.py (new file, 42 lines)

@@ -0,0 +1,42 @@
from celery import Celery
from celery.result import AsyncResult

from api.container import agent_manager
from env import settings

celery_app = Celery(__name__)
celery_app.conf.broker_url = settings["CELERY_BROKER_URL"]
celery_app.conf.result_backend = settings["CELERY_BROKER_URL"]
celery_app.conf.update(
    task_track_started=True,
    task_serializer="json",
    accept_content=["json"],  # Ignore other content
    result_serializer="json",
    enable_utc=True,
)


@celery_app.task(name="task_execute", bind=True)
def task_execute(self, session: str, prompt: str):
    executor = agent_manager.create_executor(session, self)
    response = executor({"input": prompt})
    result = {"output": response["output"]}

    previous = AsyncResult(self.request.id)
    if previous and previous.info:
        result.update(previous.info)

    return result


def get_task_result(task_id):
    return AsyncResult(task_id)


def start_worker():
    celery_app.worker_main(
        [
            "worker",
            "--loglevel=INFO",
        ]
    )
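A sketch of the producer side (what `POST /api/execute/async` does), assuming the worker above is running and a Redis broker is reachable at `CELERY_BROKER_URL`:

```python
from api.worker import get_task_result, task_execute

# Enqueue the agent run; only the task id travels back to the HTTP client.
execution = task_execute.delay("demo-session", "Write hello.txt")
print(execution.id)

# Later, the result backend exposes the final payload plus the intermediate
# states published by ExecutionTracingCallbackHandler (see callback changes below).
result = get_task_result(execution.id)
print(result.status)  # e.g. PENDING, LLM_END, TOOL_END, SUCCESS
print(result.info)    # parsed action/plan/observation metadata, if any
```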

assets/gui.png (new binary file, 302 KiB; image not shown)


@@ -4,6 +4,7 @@ from core.tools.factory import ToolsFactory
 from env import settings
 from langchain.chat_models.base import BaseChatModel
 from langchain.output_parsers.base import BaseOutputParser
+from langchain.callbacks.base import BaseCallbackManager
 from .chat_agent import ConversationalChatAgent
 from .llm import ChatOpenAI
@@ -17,8 +18,10 @@ class AgentBuilder:
         self.global_tools: list = None
         self.toolsets = toolsets
-    def build_llm(self):
-        self.llm = ChatOpenAI(temperature=0)
+    def build_llm(self, callback_manager: BaseCallbackManager = None):
+        self.llm = ChatOpenAI(
+            temperature=0, callback_manager=callback_manager, verbose=True
+        )
         self.llm.check_access()

     def build_parser(self):
@@ -40,6 +43,12 @@ class AgentBuilder:
             *ToolsFactory.create_global_tools(self.toolsets),
         ]
+    def get_parser(self):
+        if self.parser is None:
+            raise ValueError("Parser is not initialized yet")
+        return self.parser

     def get_global_tools(self):
         if self.global_tools is None:
             raise ValueError("Global tools are not initialized yet")


@@ -2,22 +2,45 @@ from typing import Any, Dict, List, Optional, Union
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult
+from celery import Task
 from ansi import ANSI, Color, Style, dim_multiline
 from logger import logger

 class EVALCallbackHandler(BaseCallbackHandler):
+    @property
+    def ignore_llm(self) -> bool:
+        return False
+
+    def set_parser(self, parser) -> None:
+        self.parser = parser
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         pass

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
+        text = response.generations[0][0].text
+        parsed = self.parser.parse_all(text)
+        logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + parsed["plan"])
+        logger.info(ANSI("What I Did").to(Color.blue()) + ": " + parsed["what_i_did"])
+        logger.info(
+            ANSI("Action").to(Color.cyan())
+            + ": "
+            + ANSI(parsed["action"]).to(Style.bold())
+        )
+        logger.info(
+            ANSI("Input").to(Color.cyan())
+            + ": "
+            + dim_multiline(parsed["action_input"])
+        )

     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        pass
+        logger.info(ANSI(f"on_llm_new_token {token}").to(Color.green(), Style.italic()))

     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -85,3 +108,90 @@ class EVALCallbackHandler(BaseCallbackHandler):
             + ": "
             + dim_multiline(finish.return_values.get("output", ""))
         )
class ExecutionTracingCallbackHandler(BaseCallbackHandler):
    def __init__(self, execution: Task):
        self.execution = execution
        self.index = 0

    def set_parser(self, parser) -> None:
        self.parser = parser

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        text = response.generations[0][0].text
        parsed = self.parser.parse_all(text)
        self.index += 1
        parsed["index"] = self.index
        self.execution.update_state(state="LLM_END", meta=parsed)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        self.execution.update_state(state="CHAIN_ERROR", meta={"error": str(error)})

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        previous = self.execution.AsyncResult(self.execution.request.id)
        self.execution.update_state(
            state="TOOL_END", meta={**previous.info, "observation": output}
        )

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        previous = self.execution.AsyncResult(self.execution.request.id)
        self.execution.update_state(
            state="TOOL_ERROR", meta={**previous.info, "error": str(error)}
        )

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        pass

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        pass
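A self-contained sketch of what this handler publishes on `on_llm_end`, using a stub in place of the bound Celery task so no broker is needed. The parser's module path and langchain's `Generation`/`LLMResult` constructors are assumptions here, not taken from this commit.

```python
from langchain.schema import Generation, LLMResult

from core.agents.callback import ExecutionTracingCallbackHandler
from core.agents.parser import EvalOutputParser  # assumed module path


class FakeTask:
    """Stands in for the bound Celery task; only update_state is exercised."""

    def update_state(self, state=None, meta=None):
        print(state, meta)


handler = ExecutionTracingCallbackHandler(FakeTask())
handler.set_parser(EvalOutputParser())

text = (
    "Action: Terminal\n"
    "Plan: - [ ] create hello.txt\n"
    "What I Did: Planned the shell command.\n"
    'Action Input: "touch hello.txt"'
)
handler.on_llm_end(LLMResult(generations=[[Generation(text=text)]]))
# -> LLM_END {'action': 'Terminal', ..., 'index': 1}
```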


@@ -1,9 +1,9 @@
-from typing import Dict
+from typing import Dict, Optional
+from celery import Task
-from langchain.agents.agent import Agent, AgentExecutor
+from langchain.agents.agent import AgentExecutor
-from langchain.agents.tools import BaseTool
-from langchain.callbacks import set_handler
 from langchain.callbacks.base import CallbackManager
+from langchain.callbacks import set_handler
 from langchain.chains.conversation.memory import ConversationBufferMemory
 from langchain.memory.chat_memory import BaseChatMemory
@@ -11,68 +11,73 @@ from core.tools.base import BaseToolSet
 from core.tools.factory import ToolsFactory
 from .builder import AgentBuilder
-from .callback import EVALCallbackHandler
+from .callback import EVALCallbackHandler, ExecutionTracingCallbackHandler

-callback_manager = CallbackManager([EVALCallbackHandler()])
 set_handler(EVALCallbackHandler())

 class AgentManager:
     def __init__(
         self,
-        agent: Agent,
-        global_tools: list[BaseTool],
         toolsets: list[BaseToolSet] = [],
     ):
-        self.agent: Agent = agent
-        self.global_tools: list[BaseTool] = global_tools
         self.toolsets: list[BaseToolSet] = toolsets
+        self.memories: Dict[str, BaseChatMemory] = {}
         self.executors: Dict[str, AgentExecutor] = {}

     def create_memory(self) -> BaseChatMemory:
         return ConversationBufferMemory(memory_key="chat_history", return_messages=True)

-    def create_executor(self, session: str) -> AgentExecutor:
-        memory: BaseChatMemory = self.create_memory()
+    def get_or_create_memory(self, session: str) -> BaseChatMemory:
+        if not (session in self.memories):
+            self.memories[session] = self.create_memory()
+        return self.memories[session]

+    def create_executor(
+        self, session: str, execution: Optional[Task] = None
+    ) -> AgentExecutor:
+        builder = AgentBuilder(self.toolsets)
+        builder.build_parser()
+        callbacks = []
+        eval_callback = EVALCallbackHandler()
+        eval_callback.set_parser(builder.get_parser())
+        callbacks.append(eval_callback)
+        if execution:
+            execution_callback = ExecutionTracingCallbackHandler(execution)
+            execution_callback.set_parser(builder.get_parser())
+            callbacks.append(execution_callback)
+        callback_manager = CallbackManager(callbacks)
+        builder.build_llm(callback_manager)
+        builder.build_global_tools()
+        memory: BaseChatMemory = self.get_or_create_memory(session)
         tools = [
-            *self.global_tools,
+            *builder.get_global_tools(),
             *ToolsFactory.create_per_session_tools(
                 self.toolsets,
                 get_session=lambda: (session, self.executors[session]),
             ),
         ]
         for tool in tools:
-            tool.set_callback_manager(callback_manager)
+            tool.callback_manager = callback_manager

-        return AgentExecutor.from_agent_and_tools(
-            agent=self.agent,
+        executor = AgentExecutor.from_agent_and_tools(
+            agent=builder.get_agent(),
             tools=tools,
             memory=memory,
             callback_manager=callback_manager,
             verbose=True,
         )
+        self.executors[session] = executor
+        return executor

-    def remove_executor(self, session: str) -> None:
-        if session in self.executors:
-            del self.executors[session]
-    def get_or_create_executor(self, session: str) -> AgentExecutor:
-        if not (session in self.executors):
-            self.executors[session] = self.create_executor(session=session)
-        return self.executors[session]

     @staticmethod
     def create(toolsets: list[BaseToolSet]) -> "AgentManager":
-        builder = AgentBuilder(toolsets)
-        builder.build_llm()
-        builder.build_parser()
-        builder.build_global_tools()
-        agent = builder.get_agent()
-        global_tools = builder.get_global_tools()
         return AgentManager(
-            agent=agent,
-            global_tools=global_tools,
             toolsets=toolsets,
         )
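A small usage sketch of the reworked manager (hypothetical session id): the synchronous endpoint builds an executor with only the logging callback, while the Celery task additionally passes its bound task so the tracing handler can publish progress.

```python
from api.container import agent_manager

# Synchronous path (POST /api/execute): EVALCallbackHandler only.
executor = agent_manager.create_executor("session-123")

# Async path (inside the bound task_execute): pass the running task instead.
# executor = agent_manager.create_executor("session-123", execution=self)

res = executor({"input": "List the files in the playground directory."})
print(res["output"])
```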


@@ -1,16 +1,31 @@
 import re
-import time
 from typing import Dict
 from langchain.output_parsers.base import BaseOutputParser
-from ansi import ANSI, Color, Style
-from core.agents.callback import dim_multiline
 from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS
-from logger import logger

 class EvalOutputParser(BaseOutputParser):
+    @staticmethod
+    def parse_all(text: str) -> Dict[str, str]:
+        regex = r"Action: (.*?)[\n]Plan:(.*)[\n]What I Did:(.*)[\n]Action Input: (.*)"
+        match = re.search(regex, text, re.DOTALL)
+        if not match:
+            raise Exception("parse error")
+        action = match.group(1).strip()
+        plan = match.group(2)
+        what_i_did = match.group(3)
+        action_input = match.group(4).strip(" ").strip('"')
+        return {
+            "action": action,
+            "plan": plan,
+            "what_i_did": what_i_did,
+            "action_input": action_input,
+        }

     def get_format_instructions(self) -> str:
         return EVAL_FORMAT_INSTRUCTIONS
@@ -20,21 +35,9 @@ class EvalOutputParser(BaseOutputParser):
         if not match:
             raise Exception("parse error")
-        action = match.group(1).strip()
-        plan = match.group(2)
-        what_i_did = match.group(3)
-        action_input = match.group(4)
-        logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + plan)
-        logger.info(ANSI("What I Did").to(Color.blue()) + ": " + what_i_did)
-        time.sleep(1)
-        logger.info(
-            ANSI("Action").to(Color.cyan()) + ": " + ANSI(action).to(Style.bold())
-        )
-        time.sleep(1)
-        logger.info(ANSI("Input").to(Color.cyan()) + ": " + dim_multiline(action_input))
-        time.sleep(1)
-        return {"action": action, "action_input": action_input.strip(" ").strip('"')}
+        parsed = EvalOutputParser.parse_all(text)
+        return {"action": parsed["action"], "action_input": parsed["action_input"]}

     def __str__(self):
         return "EvalOutputParser"


@@ -3,7 +3,7 @@ EVAL_PREFIX = """{bot_name} can execute any user's request.
 {bot_name} has permission to handle one instance and can handle the environment in it at will.
 You can code, run, debug, and test yourself. You can correct the code appropriately by looking at the error message.
-I can understand, process, and create various types of files. Every image, dataframe, audio, video must be restored in file/ directory.
+I can understand, process, and create various types of files.
 {bot_name} can do whatever it takes to execute the user's request. Let's think step by step.
 """
@@ -37,7 +37,7 @@ EVAL_SUFFIX = """TOOLS
 {bot_name} can ask the user to use tools to look up information that may be helpful in answering the users original question.
 You are very strict to the filename correctness and will never fake a file name if it does not exist.
 You will remember to provide the file name loyally if it's provided in the last tool observation.
-If you have to include files in your response, you must move the files into file/ directory and provide the filename in [file/FILENAME] format. It must be wrapped in square brackets.
+If you have to include files in your response, you must provide the filepath in [file://filepath] format. It must be wrapped in square brackets.
 The tools the human can use are:


@@ -16,6 +16,10 @@ services:
       - "8000:8000" # eval port
     env_file:
       - .env
+    environment:
+      - CELERY_BROKER_URL=redis://redis:6379
+    depends_on:
+      - redis
   eval.gpu:
     container_name: eval.gpu
@@ -40,3 +44,8 @@ services:
             - driver: nvidia
               device_ids: ["1"] # You can choose which GPU to use
               capabilities: [gpu]
+    depends_on:
+      - redis
+  redis:
+    image: redis:alpine

env.py

@@ -12,6 +12,7 @@ class DotEnv(TypedDict):
     EVAL_PORT: int
     SERVER: str
+    CELERY_BROKER_URL: str
     USE_GPU: bool  # optional
     PLAYGROUND_DIR: str  # optional
     LOG_LEVEL: str  # optional
@@ -31,6 +32,7 @@ EVAL_PORT = int(os.getenv("EVAL_PORT", 8000))
 settings: DotEnv = {
     "EVAL_PORT": EVAL_PORT,
     "MODEL_NAME": os.getenv("MODEL_NAME", "gpt-4"),
+    "CELERY_BROKER_URL": os.getenv("CELERY_BROKER_URL", "redis://localhost:6379"),
     "SERVER": os.getenv("SERVER", f"http://localhost:{EVAL_PORT}"),
     "USE_GPU": os.getenv("USE_GPU", "False").lower() == "true",
     "PLAYGROUND_DIR": os.getenv("PLAYGROUND_DIR", "playground"),

poetry.lock (generated)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
[[package]]
name = "accelerate"
@@ -153,6 +153,21 @@ files = [
[package.dependencies]
frozenlist = ">=1.1.0"
[[package]]
name = "amqp"
version = "5.1.1"
description = "Low-level AMQP client for Python (fork of amqplib)."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"},
{file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"},
]
[package.dependencies]
vine = ">=5.0.0"
[[package]]
name = "anyio"
version = "3.6.2"
@@ -224,6 +239,18 @@ soupsieve = ">1.2"
html5lib = ["html5lib"]
lxml = ["lxml"]
[[package]]
name = "billiard"
version = "3.6.4.0"
description = "Python multiprocessing fork with improvements and bugfixes"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "billiard-3.6.4.0-py3-none-any.whl", hash = "sha256:87103ea78fa6ab4d5c751c4909bcff74617d985de7fa8b672cf8618afd5a875b"},
{file = "billiard-3.6.4.0.tar.gz", hash = "sha256:299de5a8da28a783d51b197d496bef4f1595dd023a93a4f59dde1886ae905547"},
]
[[package]]
name = "bitsandbytes"
version = "0.37.2"
@@ -325,6 +352,61 @@ urllib3 = ">=1.25.4,<1.27"
[package.extras]
crt = ["awscrt (==0.16.9)"]
[[package]]
name = "celery"
version = "5.2.7"
description = "Distributed Task Queue."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "celery-5.2.7-py3-none-any.whl", hash = "sha256:138420c020cd58d6707e6257b6beda91fd39af7afde5d36c6334d175302c0e14"},
{file = "celery-5.2.7.tar.gz", hash = "sha256:fafbd82934d30f8a004f81e8f7a062e31413a23d444be8ee3326553915958c6d"},
]
[package.dependencies]
billiard = ">=3.6.4.0,<4.0"
click = ">=8.0.3,<9.0"
click-didyoumean = ">=0.0.3"
click-plugins = ">=1.1.1"
click-repl = ">=0.2.0"
kombu = ">=5.2.3,<6.0"
pytz = ">=2021.3"
vine = ">=5.0.0,<6.0"
[package.extras]
arangodb = ["pyArango (>=1.3.2)"]
auth = ["cryptography"]
azureblockblob = ["azure-storage-blob (==12.9.0)"]
brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
cassandra = ["cassandra-driver (<3.21.0)"]
consul = ["python-consul2"]
cosmosdbsql = ["pydocumentdb (==2.3.2)"]
couchbase = ["couchbase (>=3.0.0)"]
couchdb = ["pycouchdb"]
django = ["Django (>=1.11)"]
dynamodb = ["boto3 (>=1.9.178)"]
elasticsearch = ["elasticsearch"]
eventlet = ["eventlet (>=0.32.0)"]
gevent = ["gevent (>=1.5.0)"]
librabbitmq = ["librabbitmq (>=1.5.0)"]
memcache = ["pylibmc"]
mongodb = ["pymongo[srv] (>=3.11.1)"]
msgpack = ["msgpack"]
pymemcache = ["python-memcached"]
pyro = ["pyro4"]
pytest = ["pytest-celery"]
redis = ["redis (>=3.4.1,!=4.0.0,!=4.0.1)"]
s3 = ["boto3 (>=1.9.125)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
solar = ["ephem"]
sqlalchemy = ["sqlalchemy"]
sqs = ["kombu[sqs]"]
tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
zstd = ["zstandard"]
[[package]]
name = "certifi"
version = "2022.12.7"
@@ -437,6 +519,56 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "click-didyoumean"
version = "0.3.0"
description = "Enables git-like *did-you-mean* feature in click"
category = "main"
optional = false
python-versions = ">=3.6.2,<4.0.0"
files = [
{file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"},
{file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"},
]
[package.dependencies]
click = ">=7"
[[package]]
name = "click-plugins"
version = "1.1.1"
description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
category = "main"
optional = false
python-versions = "*"
files = [
{file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
{file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
]
[package.dependencies]
click = ">=4.0"
[package.extras]
dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
[[package]]
name = "click-repl"
version = "0.2.0"
description = "REPL plugin for Click"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "click-repl-0.2.0.tar.gz", hash = "sha256:cd12f68d745bf6151210790540b4cb064c7b13e571bc64b6957d98d120dacfd8"},
{file = "click_repl-0.2.0-py3-none-any.whl", hash = "sha256:94b3fbbc9406a236f176e0506524b2937e4b23b6f4c0c0b2a0a83f8a64e9194b"},
]
[package.dependencies]
click = "*"
prompt-toolkit = "*"
six = "*"
[[package]]
name = "cmake"
version = "3.26.1"
@@ -844,6 +976,38 @@ files = [
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
]
[[package]]
name = "kombu"
version = "5.2.4"
description = "Messaging library for Python."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "kombu-5.2.4-py3-none-any.whl", hash = "sha256:8b213b24293d3417bcf0d2f5537b7f756079e3ea232a8386dcc89a59fd2361a4"},
{file = "kombu-5.2.4.tar.gz", hash = "sha256:37cee3ee725f94ea8bb173eaab7c1760203ea53bbebae226328600f9d2799610"},
]
[package.dependencies]
amqp = ">=5.0.9,<6.0.0"
vine = "*"
[package.extras]
azureservicebus = ["azure-servicebus (>=7.0.0)"]
azurestoragequeues = ["azure-storage-queue"]
consul = ["python-consul (>=0.6.0)"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
mongodb = ["pymongo (>=3.3.0,<3.12.1)"]
msgpack = ["msgpack"]
pyro = ["pyro4"]
qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"]
redis = ["redis (>=3.4.1,!=4.0.0,!=4.0.1)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
sqlalchemy = ["sqlalchemy"]
sqs = ["boto3 (>=1.9.12)", "pycurl (>=7.44.1,<7.45.0)", "urllib3 (>=1.26.7)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
[[package]]
name = "langchain"
version = "0.0.115"
@@ -1528,6 +1692,21 @@ files = [
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
[[package]]
name = "prompt-toolkit"
version = "3.0.38"
description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "prompt_toolkit-3.0.38-py3-none-any.whl", hash = "sha256:45ea77a2f7c60418850331366c81cf6b5b9cf4c7fd34616f733c5427e6abbb1f"},
{file = "prompt_toolkit-3.0.38.tar.gz", hash = "sha256:23ac5d50538a9a38c8bde05fecb47d0b403ecd0662857a86f886f798563d5b9b"},
]
[package.dependencies]
wcwidth = "*"
[[package]]
name = "psutil"
version = "5.9.4"
@@ -1808,6 +1987,25 @@ files = [
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
]
[[package]]
name = "redis"
version = "4.5.4"
description = "Python client for Redis database and key-value store"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "redis-4.5.4-py3-none-any.whl", hash = "sha256:2c19e6767c474f2e85167909061d525ed65bea9301c0770bb151e041b7ac89a2"},
{file = "redis-4.5.4.tar.gz", hash = "sha256:73ec35da4da267d6847e47f68730fdd5f62e2ca69e3ef5885c6a78a9374c3893"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.2", markers = "python_version <= \"3.11.2\""}
[package.extras]
hiredis = ["hiredis (>=1.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
[[package]]
name = "regex"
version = "2023.3.23"
@@ -2411,6 +2609,15 @@ category = "dev"
optional = false
python-versions = "*"
files = [
{file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"},
{file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"},
{file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"},
{file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"},
{file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"},
{file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"},
{file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"},
{file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"},
{file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"},
{file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"},
{file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"},
{file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"},
@ -2496,6 +2703,30 @@ h11 = ">=0.8"
[package.extras] [package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[[package]]
name = "vine"
version = "5.0.0"
description = "Promises, promises, promises."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"},
{file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"},
]
[[package]]
name = "wcwidth"
version = "0.2.6"
description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"},
{file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"},
]
[[package]]
name = "wheel"
version = "0.40.0"
@@ -2633,4 +2864,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "e6e707098fa68cc228ba4bd454533f460f354991091144c0e3bde29ae0b409e1"
+content-hash = "3a7cbc62858401d620de4eeab75906ea8caa97d780632e58e1d6316681228d13"


@@ -24,6 +24,8 @@ uvicorn = "^0.21.1"
 python-ptrace = "^0.9.8"
 jinja2 = "^3.1.2"
 python-multipart = "^0.0.6"
+celery = "^5.2.7"
+redis = "^4.5.4"
 [tool.poetry.group.gpu]
 optional = true


@@ -1,6 +1,25 @@
-const setAnswer = (answer, files) => {
-  document.getElementById("answer").textContent = answer;
-  const filesDiv = document.getElementById("response-files");
+const $ = (selector) => document.querySelector(selector);
+
+const setLoader = (isLoading) => {
+  const button = $("#submit");
+  const loader = $("#submit-loader");
+  if (isLoading) {
+    button.style.display = "none";
+    loader.style.display = "block";
+  } else {
+    button.style.display = "block";
+    loader.style.display = "none";
+  }
+};
+
+const setAnswer = (answer, files = []) => {
+  if (answer) {
+    $("#answer").textContent = answer;
+  } else {
+    $("#answer").innerHTML = createSpinner();
+  }
+  const filesDiv = $("#response-files");
   filesDiv.innerHTML = "";
   files.forEach((file) => {
     const a = document.createElement("a");
@@ -12,10 +31,27 @@ const setAnswer = (answer, files) => {
   });
 };
-const submit = async () => {
-  setAnswer("Loading...", []);
+class EvalApi {
+  constructor({ onComplete, onError, onSettle, onLLMEnd, onToolEnd }) {
+    this.executionId = null;
+    this.pollInterval = null;
+    this.onComplete = (answer, files, info) => {
+      onComplete(answer, files, info);
+      onSettle();
+    };
+    this.onError = (error) => {
+      onError(error);
+      onSettle();
+    };
+    this.onLLMEnd = (info) => {
+      onLLMEnd(info);
+    };
+    this.onToolEnd = (info) => {
+      onToolEnd(info);
+    };
+  }
+
+  async uploadFiles(rawfiles) {
     const files = [];
-  const rawfiles = document.getElementById("files").files;
     if (rawfiles.length > 0) {
       const formData = new FormData();
@@ -30,11 +66,12 @@ const submit = async () => {
       files.push(...urls);
     }
-  const prompt = document.getElementById("prompt").value;
-  const session = document.getElementById("session").value;
+    return files;
+  }
+
+  async execute(prompt, session, files) {
     try {
-    const response = await fetch("/api/execute", {
+      const response = await fetch("/api/execute/async", {
         method: "POST",
         headers: {
           "Content-Type": "application/json",
@@ -48,14 +85,193 @@ const submit = async () => {
       if (response.status !== 200) {
         throw new Error(await response.text());
       }
-    const { answer, files: responseFiles } = await response.json();
-    setAnswer(answer, responseFiles);
+      const { id: executionId } = await response.json();
+      this.executionId = executionId;
+      this.pollInterval = setInterval(this.poll.bind(this), 1000);
     } catch (e) {
-    setAnswer("Error: " + e.message, []);
+      clearInterval(this.pollInterval);
+      this.onError(e);
     }
+  }
async poll() {
try {
const response = await fetch(`/api/execute/async/${this.executionId}`, {
method: "GET",
});
if (response.status !== 200) {
throw new Error(await response.text());
}
const { status, result, info } = await response.json();
switch (status) {
case "PENDING":
break;
case "FAILURE":
throw new Error("Execution failed");
case "LLM_END":
this.onLLMEnd(info);
break;
case "TOOL_END":
this.onToolEnd(info);
break;
case "SUCCESS":
clearInterval(this.pollInterval);
this.onComplete(result.answer, result.files, info);
break;
}
} catch (e) {
clearInterval(this.pollInterval);
this.onError(e);
}
}
}
const submit = async () => {
setAnswer("");
setLoader(true);
const actions = $("#actions");
actions.innerHTML = "";
let currentActionIndex = 0;
const onInfo = (info) => {
if (currentActionIndex >= info.index) {
return;
}
currentActionIndex = info.index;
const w = document.createElement("div");
w.innerHTML = createActionCard(
info.index,
info.action,
info.action_input,
info.what_i_did,
info.plan,
info.observation
);
actions.appendChild(w);
};
const api = new EvalApi({
onSettle: () => setLoader(false),
onError: (error) => setAnswer(`Error: ${error.message}`, []),
onComplete: (answer, files, info) => {
setAnswer(answer, files);
onInfo(info);
},
onLLMEnd: onInfo,
onToolEnd: onInfo,
});
const prompt = $("#prompt").value;
const session = $("#session").value;
const files = await api.uploadFiles($("#files").files);
await api.execute(prompt, session, files);
 };
 const setRandomSessionId = () => {
   const sessionId = Math.random().toString(36).substring(2, 15);
-  document.getElementById("session").value = sessionId;
+  $("#session").value = sessionId;
 };
const createSpinner = () => `
<div class="text-center">
<div class="spinner-border m-3"></div>
</div>
`;
const createActionCard = (
index,
action,
input,
whatIdid,
plan,
observation
) => `
<div class="accordion m-2">
<div class="accordion-item">
<h2 class="accordion-header">
<button class="accordion-button">
<span class="text-secondary">
Action #${index}
</span>
<span class="mx-1">-</span>
<span class="fw-bold">
${action}
</span>
</button>
</h2>
<div class="accordion-collapse collapse show">
<div class="accordion-body">
<table class="table">
<tbody>
${
action !== "Final Answer"
? `<tr>
<th style="width: 100px">Input</th>
<td><div>${input}</div></td>
</tr>`
: ""
}
<tr>
<th style="width: 100px">What I Did</th>
<td><div>${whatIdid}</div></td>
</tr>
</tbody>
</table>
<table class="table">
<thead>
<tr>
<th colspan="2">Plan</th>
</tr>
</thead>
<tbody>
${plan
.split("- ")
.map((p) => p.trim())
.filter((p) => p.length > 0)
.map(
(p) => `
<tr>
${
p.startsWith("[ ]")
? `<td><input class="form-check-input" type="checkbox" /></td>
<td>${p.replace("[ ]", "")}</td>`
: ""
}
${
p.startsWith("[x]")
? `<td><input class="form-check-input" type="checkbox" checked/></td>
<td>${p.replace("[x]", "")}</td>`
: ""
}
</tr>`
)
.join("")}
</tbody>
</table>
${
action !== "Final Answer"
? `<table class="table">
<thead>
<tr>
<th colspan="2">Observation</th>
</tr>
</thead>
<tbody>
<tr>
<td>
<div>${observation}</div>
</td>
</tr>
</tbody>
</table>`
: ""
}
</div>
</div>
</div>
</div>`;


@@ -1,7 +1,7 @@
 .logo {
   border-radius: 50%;
   overflow: hidden;
-  height: 64px;
-  width: 64px;
+  height: 48px;
+  width: 48px;
   margin-right: 20px;
 }