feat: async execution

pull/31/head
hanchchch 1 year ago
parent a15ba07eca
commit d881b3a703

@ -0,0 +1,62 @@
import os
import re
from pathlib import Path
from typing import Dict, List
from fastapi.templating import Jinja2Templates
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"])
toolsets: List[BaseToolSet] = [
Terminal(),
CodeEditor(),
RequestsGet(),
ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
if settings["USE_GPU"]:
import torch
from core.handlers.image import ImageCaptioning
from core.tools.gpu import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
if torch.cuda.is_available():
toolsets.extend(
[
Text2Image("cuda"),
ImageEditing("cuda"),
InstructPix2Pix("cuda"),
VisualQuestionAnswering("cuda"),
]
)
handlers[FileType.IMAGE] = ImageCaptioning("cuda")
agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
uploader = StaticUploader.from_settings(
settings, path=BASE_DIR / "static", endpoint="static"
)
reload_dirs = [BASE_DIR / "core", BASE_DIR / "api"]

@ -1,73 +1,22 @@
import os
import re
from pathlib import Path
from multiprocessing import Process
from tempfile import NamedTemporaryFile
from typing import Dict, List, TypedDict
from typing import List, TypedDict
import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from api.container import agent_manager, file_handler, reload_dirs, templates, uploader
from api.worker import get_task_result, start_worker, task_execute
from env import settings
app = FastAPI()
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(BASE_DIR / settings["PLAYGROUND_DIR"])
uploader = StaticUploader.from_settings(
settings, path=BASE_DIR / "static", endpoint="static"
)
app.mount("/static", StaticFiles(directory=uploader.path), name="static")
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
toolsets: List[BaseToolSet] = [
Terminal(),
CodeEditor(),
RequestsGet(),
ExitConversation(),
]
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
if settings["USE_GPU"]:
import torch
from core.handlers.image import ImageCaptioning
from core.tools.gpu import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
if torch.cuda.is_available():
toolsets.extend(
[
Text2Image("cuda"),
ImageEditing("cuda"),
InstructPix2Pix("cuda"),
VisualQuestionAnswering("cuda"),
]
)
handlers[FileType.IMAGE] = ImageCaptioning("cuda")
agent_manager = AgentManager.create(toolsets=toolsets)
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
class ExecuteRequest(BaseModel):
session: str
@ -127,15 +76,40 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse:
}
@app.post("/api/execute/async")
async def execute_async(request: ExecuteRequest):
query = request.prompt
files = request.files
session = request.session
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
promptedQuery += query
execution = task_execute.delay(session, promptedQuery)
return {"id": execution.id}
@app.get("/api/execute/async/{execution_id}")
async def execute_async(execution_id: str):
result = get_task_result(execution_id)
return {
"task_id": execution_id,
"status": result.status,
"result": result.result,
}
def serve():
uvicorn.run("api.main:app", host="0.0.0.0", port=settings["EVAL_PORT"])
def dev():
p = Process(target=start_worker, args=[])
p.start()
uvicorn.run(
"api.main:app",
host="0.0.0.0",
port=settings["EVAL_PORT"],
reload=True,
reload_dirs=[BASE_DIR / "core", BASE_DIR / "api"],
reload_dirs=reload_dirs,
)

@ -0,0 +1,35 @@
from celery import Celery
from celery.result import AsyncResult
from api.container import agent_manager
from env import settings
celery_app = Celery(__name__)
celery_app.conf.broker_url = settings["CELERY_BROKER_URL"]
celery_app.conf.result_backend = settings["CELERY_BROKER_URL"]
celery_app.conf.update(
task_serializer="json",
accept_content=["json"], # Ignore other content
result_serializer="json",
enable_utc=True,
)
@celery_app.task(name="task_execute")
def task_execute(session: str, prompt: str):
executor = agent_manager.get_or_create_executor(session)
response = executor({"input": prompt})
return {"output": response["output"]}
def get_task_result(task_id):
return AsyncResult(task_id)
def start_worker():
celery_app.worker_main(
[
"worker",
"--loglevel=INFO",
]
)

@ -16,6 +16,10 @@ services:
- "8000:8000" # eval port
env_file:
- .env
environment:
- CELERY_BROKER_URL=redis://redis:6379
depends_on:
- redis
eval.gpu:
container_name: eval.gpu
@ -40,3 +44,10 @@ services:
- driver: nvidia
device_ids: ["1"] # You can choose which GPU to use
capabilities: [gpu]
depends_on:
- redis
redis:
image: redis:alpine
ports:
- "6379:6379"

@ -12,6 +12,7 @@ class DotEnv(TypedDict):
EVAL_PORT: int
SERVER: str
CELERY_BROKER_URL: str
USE_GPU: bool # optional
PLAYGROUND_DIR: str # optional
LOG_LEVEL: str # optional
@ -31,6 +32,7 @@ EVAL_PORT = int(os.getenv("EVAL_PORT", 8000))
settings: DotEnv = {
"EVAL_PORT": EVAL_PORT,
"MODEL_NAME": os.getenv("MODEL_NAME", "gpt-4"),
"CELERY_BROKER_URL": os.getenv("CELERY_BROKER_URL", "redis://localhost:6379"),
"SERVER": os.getenv("SERVER", f"http://localhost:{EVAL_PORT}"),
"USE_GPU": os.getenv("USE_GPU", "False").lower() == "true",
"PLAYGROUND_DIR": os.getenv("PLAYGROUND_DIR", "playground"),

235
poetry.lock generated

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
[[package]]
name = "accelerate"
@ -153,6 +153,21 @@ files = [
[package.dependencies]
frozenlist = ">=1.1.0"
[[package]]
name = "amqp"
version = "5.1.1"
description = "Low-level AMQP client for Python (fork of amqplib)."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"},
{file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"},
]
[package.dependencies]
vine = ">=5.0.0"
[[package]]
name = "anyio"
version = "3.6.2"
@ -224,6 +239,18 @@ soupsieve = ">1.2"
html5lib = ["html5lib"]
lxml = ["lxml"]
[[package]]
name = "billiard"
version = "3.6.4.0"
description = "Python multiprocessing fork with improvements and bugfixes"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "billiard-3.6.4.0-py3-none-any.whl", hash = "sha256:87103ea78fa6ab4d5c751c4909bcff74617d985de7fa8b672cf8618afd5a875b"},
{file = "billiard-3.6.4.0.tar.gz", hash = "sha256:299de5a8da28a783d51b197d496bef4f1595dd023a93a4f59dde1886ae905547"},
]
[[package]]
name = "bitsandbytes"
version = "0.37.2"
@ -325,6 +352,61 @@ urllib3 = ">=1.25.4,<1.27"
[package.extras]
crt = ["awscrt (==0.16.9)"]
[[package]]
name = "celery"
version = "5.2.7"
description = "Distributed Task Queue."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "celery-5.2.7-py3-none-any.whl", hash = "sha256:138420c020cd58d6707e6257b6beda91fd39af7afde5d36c6334d175302c0e14"},
{file = "celery-5.2.7.tar.gz", hash = "sha256:fafbd82934d30f8a004f81e8f7a062e31413a23d444be8ee3326553915958c6d"},
]
[package.dependencies]
billiard = ">=3.6.4.0,<4.0"
click = ">=8.0.3,<9.0"
click-didyoumean = ">=0.0.3"
click-plugins = ">=1.1.1"
click-repl = ">=0.2.0"
kombu = ">=5.2.3,<6.0"
pytz = ">=2021.3"
vine = ">=5.0.0,<6.0"
[package.extras]
arangodb = ["pyArango (>=1.3.2)"]
auth = ["cryptography"]
azureblockblob = ["azure-storage-blob (==12.9.0)"]
brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
cassandra = ["cassandra-driver (<3.21.0)"]
consul = ["python-consul2"]
cosmosdbsql = ["pydocumentdb (==2.3.2)"]
couchbase = ["couchbase (>=3.0.0)"]
couchdb = ["pycouchdb"]
django = ["Django (>=1.11)"]
dynamodb = ["boto3 (>=1.9.178)"]
elasticsearch = ["elasticsearch"]
eventlet = ["eventlet (>=0.32.0)"]
gevent = ["gevent (>=1.5.0)"]
librabbitmq = ["librabbitmq (>=1.5.0)"]
memcache = ["pylibmc"]
mongodb = ["pymongo[srv] (>=3.11.1)"]
msgpack = ["msgpack"]
pymemcache = ["python-memcached"]
pyro = ["pyro4"]
pytest = ["pytest-celery"]
redis = ["redis (>=3.4.1,!=4.0.0,!=4.0.1)"]
s3 = ["boto3 (>=1.9.125)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
solar = ["ephem"]
sqlalchemy = ["sqlalchemy"]
sqs = ["kombu[sqs]"]
tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
zstd = ["zstandard"]
[[package]]
name = "certifi"
version = "2022.12.7"
@ -437,6 +519,56 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "click-didyoumean"
version = "0.3.0"
description = "Enables git-like *did-you-mean* feature in click"
category = "main"
optional = false
python-versions = ">=3.6.2,<4.0.0"
files = [
{file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"},
{file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"},
]
[package.dependencies]
click = ">=7"
[[package]]
name = "click-plugins"
version = "1.1.1"
description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
category = "main"
optional = false
python-versions = "*"
files = [
{file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
{file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
]
[package.dependencies]
click = ">=4.0"
[package.extras]
dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
[[package]]
name = "click-repl"
version = "0.2.0"
description = "REPL plugin for Click"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "click-repl-0.2.0.tar.gz", hash = "sha256:cd12f68d745bf6151210790540b4cb064c7b13e571bc64b6957d98d120dacfd8"},
{file = "click_repl-0.2.0-py3-none-any.whl", hash = "sha256:94b3fbbc9406a236f176e0506524b2937e4b23b6f4c0c0b2a0a83f8a64e9194b"},
]
[package.dependencies]
click = "*"
prompt-toolkit = "*"
six = "*"
[[package]]
name = "cmake"
version = "3.26.1"
@ -844,6 +976,38 @@ files = [
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
]
[[package]]
name = "kombu"
version = "5.2.4"
description = "Messaging library for Python."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "kombu-5.2.4-py3-none-any.whl", hash = "sha256:8b213b24293d3417bcf0d2f5537b7f756079e3ea232a8386dcc89a59fd2361a4"},
{file = "kombu-5.2.4.tar.gz", hash = "sha256:37cee3ee725f94ea8bb173eaab7c1760203ea53bbebae226328600f9d2799610"},
]
[package.dependencies]
amqp = ">=5.0.9,<6.0.0"
vine = "*"
[package.extras]
azureservicebus = ["azure-servicebus (>=7.0.0)"]
azurestoragequeues = ["azure-storage-queue"]
consul = ["python-consul (>=0.6.0)"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
mongodb = ["pymongo (>=3.3.0,<3.12.1)"]
msgpack = ["msgpack"]
pyro = ["pyro4"]
qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"]
redis = ["redis (>=3.4.1,!=4.0.0,!=4.0.1)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
sqlalchemy = ["sqlalchemy"]
sqs = ["boto3 (>=1.9.12)", "pycurl (>=7.44.1,<7.45.0)", "urllib3 (>=1.26.7)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
[[package]]
name = "langchain"
version = "0.0.115"
@ -1528,6 +1692,21 @@ files = [
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
[[package]]
name = "prompt-toolkit"
version = "3.0.38"
description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "prompt_toolkit-3.0.38-py3-none-any.whl", hash = "sha256:45ea77a2f7c60418850331366c81cf6b5b9cf4c7fd34616f733c5427e6abbb1f"},
{file = "prompt_toolkit-3.0.38.tar.gz", hash = "sha256:23ac5d50538a9a38c8bde05fecb47d0b403ecd0662857a86f886f798563d5b9b"},
]
[package.dependencies]
wcwidth = "*"
[[package]]
name = "psutil"
version = "5.9.4"
@ -1808,6 +1987,25 @@ files = [
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
]
[[package]]
name = "redis"
version = "4.5.4"
description = "Python client for Redis database and key-value store"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "redis-4.5.4-py3-none-any.whl", hash = "sha256:2c19e6767c474f2e85167909061d525ed65bea9301c0770bb151e041b7ac89a2"},
{file = "redis-4.5.4.tar.gz", hash = "sha256:73ec35da4da267d6847e47f68730fdd5f62e2ca69e3ef5885c6a78a9374c3893"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.2", markers = "python_version <= \"3.11.2\""}
[package.extras]
hiredis = ["hiredis (>=1.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
[[package]]
name = "regex"
version = "2023.3.23"
@ -2411,6 +2609,15 @@ category = "dev"
optional = false
python-versions = "*"
files = [
{file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"},
{file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"},
{file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"},
{file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"},
{file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"},
{file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"},
{file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"},
{file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"},
{file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"},
{file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"},
{file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"},
{file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"},
@ -2496,6 +2703,30 @@ h11 = ">=0.8"
[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[[package]]
name = "vine"
version = "5.0.0"
description = "Promises, promises, promises."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"},
{file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"},
]
[[package]]
name = "wcwidth"
version = "0.2.6"
description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"},
{file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"},
]
[[package]]
name = "wheel"
version = "0.40.0"
@ -2633,4 +2864,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e6e707098fa68cc228ba4bd454533f460f354991091144c0e3bde29ae0b409e1"
content-hash = "3a7cbc62858401d620de4eeab75906ea8caa97d780632e58e1d6316681228d13"

@ -24,6 +24,8 @@ uvicorn = "^0.21.1"
python-ptrace = "^0.9.8"
jinja2 = "^3.1.2"
python-multipart = "^0.0.6"
celery = "^5.2.7"
redis = "^4.5.4"
[tool.poetry.group.gpu]
optional = true

Loading…
Cancel
Save