mirror of
https://github.com/corca-ai/EVAL
synced 2024-10-30 09:20:44 +00:00
refactor: handlers and tools
This commit is contained in:
parent
2904d5fbf2
commit
83095ec2ce
143
agent.py
143
agent.py
@ -1,14 +1,15 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from llm import ChatOpenAI
|
||||
from langchain.agents import load_tools
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.initialize import initialize_agent
|
||||
from langchain.chains.conversation.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
|
||||
from utils import AWESOMEGPT_PREFIX, AWESOMEGPT_SUFFIX
|
||||
from prompts.input import AWESOMEGPT_PREFIX, AWESOMEGPT_SUFFIX
|
||||
|
||||
from tools.factory import ToolsFactory
|
||||
from tools.cpu import (
|
||||
Terminal,
|
||||
RequestsGet,
|
||||
@ -19,67 +20,97 @@ from tools.gpu import (
|
||||
ImageEditing,
|
||||
InstructPix2Pix,
|
||||
Text2Image,
|
||||
ImageCaptioning,
|
||||
VisualQuestionAnswering,
|
||||
)
|
||||
from handler import Handler, FileType
|
||||
from handlers.base import FileHandler, FileType
|
||||
from handlers.image import ImageCaptioning
|
||||
from handlers.dataframe import CsvToDataframe
|
||||
from env import settings
|
||||
|
||||
|
||||
def get_agent() -> Tuple[AgentExecutor, Handler]:
|
||||
print("Initializing AwesomeGPT")
|
||||
llm = ChatOpenAI(temperature=0)
|
||||
class AgentFactory:
|
||||
def __init__(self):
|
||||
self.llm: BaseChatModel = None
|
||||
self.memory: BaseChatMemory = None
|
||||
self.tools: list = None
|
||||
self.handler: FileHandler = None
|
||||
|
||||
tool_names = ["python_repl", "wikipedia"]
|
||||
|
||||
if settings["SERPAPI_API_KEY"]:
|
||||
tool_names.append("serpapi")
|
||||
if settings["BING_SEARCH_URL"] and settings["BING_SUBSCRIPTION_KEY"]:
|
||||
tool_names.append("bing-search")
|
||||
tools = [*load_tools(tool_names, llm=llm)]
|
||||
|
||||
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
||||
|
||||
models = {
|
||||
"Terminal": Terminal(),
|
||||
"RequestsGet": RequestsGet(),
|
||||
"WineDB": WineDB(),
|
||||
"ExitConversation": ExitConversation(memory),
|
||||
"Text2Image": Text2Image("cuda"),
|
||||
"ImageEditing": ImageEditing("cuda"),
|
||||
"InstructPix2Pix": InstructPix2Pix("cuda"),
|
||||
"VisualQuestionAnswering": VisualQuestionAnswering("cuda"),
|
||||
}
|
||||
|
||||
for _, instance in models.items():
|
||||
for e in dir(instance):
|
||||
if e.startswith("inference"):
|
||||
func = getattr(instance, e)
|
||||
tools.append(
|
||||
Tool(name=func.name, description=func.description, func=func)
|
||||
)
|
||||
|
||||
handle_models: Dict[FileType, str] = {
|
||||
FileType.IMAGE: ImageCaptioning("cuda"),
|
||||
}
|
||||
|
||||
handler = Handler(
|
||||
handle_func={
|
||||
file_type: model.inference for file_type, model in handle_models.items()
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
initialize_agent(
|
||||
tools,
|
||||
llm,
|
||||
def create(self):
|
||||
print("Initializing AwesomeGPT")
|
||||
self.create_llm()
|
||||
self.create_memory()
|
||||
self.create_tools()
|
||||
self.create_handler()
|
||||
return initialize_agent(
|
||||
self.tools,
|
||||
self.llm,
|
||||
agent="chat-conversational-react-description",
|
||||
verbose=True,
|
||||
memory=memory,
|
||||
memory=self.memory,
|
||||
agent_kwargs={
|
||||
"system_message": AWESOMEGPT_PREFIX,
|
||||
"human_message": AWESOMEGPT_SUFFIX,
|
||||
},
|
||||
),
|
||||
handler,
|
||||
)
|
||||
)
|
||||
|
||||
def create_llm(self):
|
||||
self.llm = ChatOpenAI(temperature=0)
|
||||
|
||||
def create_memory(self):
|
||||
self.memory = ConversationBufferMemory(
|
||||
memory_key="chat_history", return_messages=True
|
||||
)
|
||||
|
||||
def create_tools(self):
|
||||
if self.memory is None:
|
||||
raise ValueError("Memory must be initialized before tools")
|
||||
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized before tools")
|
||||
|
||||
toolnames = ["python_repl", "wikipedia"]
|
||||
|
||||
if settings["SERPAPI_API_KEY"]:
|
||||
toolnames.append("serpapi")
|
||||
if settings["BING_SEARCH_URL"] and settings["BING_SUBSCRIPTION_KEY"]:
|
||||
toolnames.append("bing-search")
|
||||
|
||||
toolsets = [
|
||||
Terminal(),
|
||||
RequestsGet(),
|
||||
ExitConversation(self.memory),
|
||||
Text2Image("cuda"),
|
||||
ImageEditing("cuda"),
|
||||
InstructPix2Pix("cuda"),
|
||||
VisualQuestionAnswering("cuda"),
|
||||
]
|
||||
|
||||
if settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]:
|
||||
toolsets.append(WineDB())
|
||||
|
||||
self.tools = [
|
||||
*ToolsFactory.from_names(toolnames, llm=self.llm),
|
||||
*ToolsFactory.from_toolsets(toolsets),
|
||||
]
|
||||
|
||||
def create_handler(self):
|
||||
self.handler = FileHandler(
|
||||
{
|
||||
FileType.IMAGE: ImageCaptioning("cuda"),
|
||||
FileType.DATAFRAME: CsvToDataframe(),
|
||||
}
|
||||
)
|
||||
|
||||
def get_handler(self):
|
||||
if self.handler is None:
|
||||
raise ValueError("Handler must be initialized before returning")
|
||||
|
||||
return self.handler
|
||||
|
||||
@staticmethod
|
||||
def get_agent_and_handler() -> Tuple[AgentExecutor, FileHandler]:
|
||||
factory = AgentFactory()
|
||||
agent = factory.create()
|
||||
handler = factory.get_handler()
|
||||
|
||||
return (agent, handler)
|
||||
|
89
handler.py
89
handler.py
@ -1,89 +0,0 @@
|
||||
import os
|
||||
import requests
|
||||
import uuid
|
||||
from typing import Callable, Dict
|
||||
from enum import Enum
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from utils import IMAGE_PROMPT, DATAFRAME_PROMPT
|
||||
|
||||
|
||||
class FileType(Enum):
|
||||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
DATAFRAME = "dataframe"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class Handler:
|
||||
def __init__(self, handle_func: Dict[FileType, Callable]):
|
||||
self.handle_func = handle_func
|
||||
|
||||
def handle(self, i: int, file_name: str) -> str:
|
||||
"""
|
||||
Parse file type from file name (ex. image, audio, video, dataframe, etc.)
|
||||
"""
|
||||
file_type = file_name.split("?")[0]
|
||||
|
||||
if file_type.endswith(".png") or file_type.endswith(".jpg"):
|
||||
return self.handle_image(i, file_name)
|
||||
elif file_type.endswith(".mp3") or file_type.endswith(".wav"):
|
||||
return self.handle_audio(i, file_name)
|
||||
elif file_type.endswith(".mp4") or file_type.endswith(".avi"):
|
||||
return self.handle_video(i, file_name)
|
||||
elif file_type.endswith(".csv"):
|
||||
return self.handle_dataframe(i, file_name)
|
||||
else:
|
||||
return self.handle_unknown(i, file_name)
|
||||
|
||||
def handle_image(self, i: int, remote_filename: str) -> str:
|
||||
img_data = requests.get(remote_filename).content
|
||||
local_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
|
||||
with open(local_filename, "wb") as f:
|
||||
size = f.write(img_data)
|
||||
print(f"Inputs: {remote_filename} ({size//1000}MB) => {local_filename}")
|
||||
img = Image.open(local_filename)
|
||||
width, height = img.size
|
||||
ratio = min(512 / width, 512 / height)
|
||||
width_new, height_new = (round(width * ratio), round(height * ratio))
|
||||
img = img.resize((width_new, height_new))
|
||||
img = img.convert("RGB")
|
||||
img.save(local_filename, "PNG")
|
||||
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
|
||||
try:
|
||||
description = self.handle_func[FileType.IMAGE](local_filename)
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
||||
|
||||
return IMAGE_PROMPT.format(
|
||||
i=i, filename=local_filename, description=description
|
||||
)
|
||||
|
||||
def handle_audio(self, i: int, remote_filename: str) -> str:
|
||||
return ""
|
||||
|
||||
def handle_video(self, i: int, remote_filename: str) -> str:
|
||||
return ""
|
||||
|
||||
def handle_dataframe(self, i: int, remote_filename: str) -> str:
|
||||
content = requests.get(remote_filename).content
|
||||
local_filename = os.path.join("dataframe/", str(uuid.uuid4())[0:8] + ".csv")
|
||||
with open(local_filename, "wb") as f:
|
||||
size = f.write(content)
|
||||
print(f"Inputs: {remote_filename} ({size//1000}MB) => {local_filename}")
|
||||
df = pd.read_csv(local_filename)
|
||||
try:
|
||||
description = str(df.describe())
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
||||
|
||||
return DATAFRAME_PROMPT.format(
|
||||
i=i, filename=local_filename, description=description
|
||||
)
|
||||
|
||||
def handle_unknown(self, i: int, file: str) -> str:
|
||||
return ""
|
75
handlers/base.py
Normal file
75
handlers/base.py
Normal file
@ -0,0 +1,75 @@
|
||||
import os
|
||||
import requests
|
||||
import uuid
|
||||
from typing import Dict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class FileType(Enum):
|
||||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
DATAFRAME = "dataframe"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
@staticmethod
|
||||
def from_filename(url: str) -> "FileType":
|
||||
filename = url.split("?")[0]
|
||||
|
||||
if filename.endswith(".png") or filename.endswith(".jpg"):
|
||||
return FileType.IMAGE
|
||||
elif filename.endswith(".mp3") or filename.endswith(".wav"):
|
||||
return FileType.AUDIO
|
||||
elif filename.endswith(".mp4") or filename.endswith(".avi"):
|
||||
return FileType.VIDEO
|
||||
elif filename.endswith(".csv"):
|
||||
return FileType.DATAFRAME
|
||||
else:
|
||||
return FileType.UNKNOWN
|
||||
|
||||
@staticmethod
|
||||
def from_url(url: str) -> "FileType":
|
||||
return FileType.from_filename(url.split("?")[0])
|
||||
|
||||
def to_extension(self) -> str:
|
||||
if self == FileType.IMAGE:
|
||||
return ".png"
|
||||
elif self == FileType.AUDIO:
|
||||
return ".mp3"
|
||||
elif self == FileType.VIDEO:
|
||||
return ".mp4"
|
||||
elif self == FileType.DATAFRAME:
|
||||
return ".csv"
|
||||
else:
|
||||
return ".unknown"
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
def handle(self, filename: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FileHandler:
|
||||
def __init__(self, handlers: Dict[FileType, BaseHandler]):
|
||||
self.handlers = handlers
|
||||
|
||||
def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler":
|
||||
self.handlers[filetype] = handler
|
||||
return self
|
||||
|
||||
def download(self, url: str) -> str:
|
||||
filetype = FileType.from_url(url)
|
||||
data = requests.get(url).content
|
||||
local_filename = os.path.join(
|
||||
filetype.value, str(uuid.uuid4())[0:8] + filetype.to_extension()
|
||||
)
|
||||
with open(local_filename, "wb") as f:
|
||||
size = f.write(data)
|
||||
print(f"Inputs: {url} ({size//1000}MB) => {local_filename}")
|
||||
return local_filename
|
||||
|
||||
def handle(self, url: str) -> str:
|
||||
try:
|
||||
return self.handlers[FileType.from_url(url)].handle(self.download(url))
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
11
handlers/dataframe.py
Normal file
11
handlers/dataframe.py
Normal file
@ -0,0 +1,11 @@
|
||||
import pandas as pd
|
||||
from prompts.file import DATAFRAME_PROMPT
|
||||
|
||||
from .base import BaseHandler
|
||||
|
||||
|
||||
class CsvToDataframe(BaseHandler):
|
||||
def handle(self, filename: str):
|
||||
df = pd.read_csv(filename)
|
||||
description = str(df.describe())
|
||||
return DATAFRAME_PROMPT.format(filename=filename, description=description)
|
43
handlers/image.py
Normal file
43
handlers/image.py
Normal file
@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
BlipProcessor,
|
||||
BlipForConditionalGeneration,
|
||||
)
|
||||
from prompts.file import IMAGE_PROMPT
|
||||
|
||||
from .base import BaseHandler
|
||||
|
||||
|
||||
class ImageCaptioning(BaseHandler):
|
||||
def __init__(self, device):
|
||||
print("Initializing ImageCaptioning to %s" % device)
|
||||
self.device = device
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.processor = BlipProcessor.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base"
|
||||
)
|
||||
self.model = BlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
|
||||
).to(self.device)
|
||||
|
||||
def handle(self, filename: str):
|
||||
img = Image.open(filename)
|
||||
width, height = img.size
|
||||
ratio = min(512 / width, 512 / height)
|
||||
width_new, height_new = (round(width * ratio), round(height * ratio))
|
||||
img = img.resize((width_new, height_new))
|
||||
img = img.convert("RGB")
|
||||
img.save(filename, "PNG")
|
||||
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
|
||||
|
||||
inputs = self.processor(Image.open(filename), return_tensors="pt").to(
|
||||
self.device, self.torch_dtype
|
||||
)
|
||||
out = self.model.generate(**inputs)
|
||||
description = self.processor.decode(out[0], skip_special_tokens=True)
|
||||
print(
|
||||
f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text: {description}"
|
||||
)
|
||||
|
||||
return IMAGE_PROMPT.format(filename=filename, description=description)
|
7
main.py
7
main.py
@ -5,12 +5,12 @@ from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from s3 import upload
|
||||
|
||||
from utils import ERROR_PROMPT
|
||||
from agent import get_agent
|
||||
from prompts.error import ERROR_PROMPT
|
||||
from agent import AgentFactory
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
agent, handler = get_agent()
|
||||
agent, handler = AgentFactory.get_agent_and_handler()
|
||||
|
||||
|
||||
class Request(BaseModel):
|
||||
@ -42,7 +42,6 @@ async def command(request: Request) -> Response:
|
||||
print("======>Previous memory:\n %s" % agent.memory)
|
||||
|
||||
promptedQuery = ""
|
||||
import time
|
||||
|
||||
for i, file in enumerate(files):
|
||||
promptedQuery += handler.handle(i + 1, file)
|
||||
|
91
poetry.lock
generated
91
poetry.lock
generated
@ -224,6 +224,55 @@ soupsieve = ">1.2"
|
||||
html5lib = ["html5lib"]
|
||||
lxml = ["lxml"]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "23.1.0"
|
||||
description = "The uncompromising code formatter."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "black-23.1.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221"},
|
||||
{file = "black-23.1.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26"},
|
||||
{file = "black-23.1.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b"},
|
||||
{file = "black-23.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"},
|
||||
{file = "black-23.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074"},
|
||||
{file = "black-23.1.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27"},
|
||||
{file = "black-23.1.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648"},
|
||||
{file = "black-23.1.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958"},
|
||||
{file = "black-23.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a"},
|
||||
{file = "black-23.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481"},
|
||||
{file = "black-23.1.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad"},
|
||||
{file = "black-23.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8"},
|
||||
{file = "black-23.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24"},
|
||||
{file = "black-23.1.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6"},
|
||||
{file = "black-23.1.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd"},
|
||||
{file = "black-23.1.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580"},
|
||||
{file = "black-23.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468"},
|
||||
{file = "black-23.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753"},
|
||||
{file = "black-23.1.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651"},
|
||||
{file = "black-23.1.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06"},
|
||||
{file = "black-23.1.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739"},
|
||||
{file = "black-23.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9"},
|
||||
{file = "black-23.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555"},
|
||||
{file = "black-23.1.0-py3-none-any.whl", hash = "sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32"},
|
||||
{file = "black-23.1.0.tar.gz", hash = "sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=8.0.0"
|
||||
mypy-extensions = ">=0.4.3"
|
||||
packaging = ">=22.0"
|
||||
pathspec = ">=0.9.0"
|
||||
platformdirs = ">=2"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
colorama = ["colorama (>=0.4.3)"]
|
||||
d = ["aiohttp (>=3.7.4)"]
|
||||
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
||||
uvloop = ["uvloop (>=0.15.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.26.94"
|
||||
@ -1359,6 +1408,18 @@ pytz = ">=2020.1"
|
||||
[package.extras]
|
||||
test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"]
|
||||
|
||||
[[package]]
|
||||
name = "pathspec"
|
||||
version = "0.11.1"
|
||||
description = "Utility library for gitignore style pattern matching of file paths."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"},
|
||||
{file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "9.4.0"
|
||||
@ -1450,6 +1511,22 @@ files = [
|
||||
docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"]
|
||||
tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
|
||||
|
||||
[[package]]
|
||||
name = "platformdirs"
|
||||
version = "3.1.1"
|
||||
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "platformdirs-3.1.1-py3-none-any.whl", hash = "sha256:e5986afb596e4bb5bde29a79ac9061aa955b94fca2399b7aaac4090860920dd8"},
|
||||
{file = "platformdirs-3.1.1.tar.gz", hash = "sha256:024996549ee88ec1a9aa99ff7f8fc819bb59e2c3477b410d90a16d32d6e707aa"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
|
||||
test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
version = "5.9.4"
|
||||
@ -2118,6 +2195,18 @@ dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
description = "A lil' TOML parser"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.0.0"
|
||||
@ -2493,4 +2582,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "e3fa682690d35abbaddcf5e2ab258f98f516b9e472cb9f6b6fcc184575d8fbfc"
|
||||
content-hash = "2cebdbe83e06dd233e889407b8bfba2e2a58eefc04323ad1adbb4cb8a31f8427"
|
||||
|
1
prompts/error.py
Normal file
1
prompts/error.py
Normal file
@ -0,0 +1 @@
|
||||
ERROR_PROMPT = "An error has occurred for the following text: \n{promptedQuery} Please explain this error.\n {e}"
|
25
prompts/file.py
Normal file
25
prompts/file.py
Normal file
@ -0,0 +1,25 @@
|
||||
IMAGE_PROMPT = """
|
||||
provide a figure named {filename}. The description is: {description}.
|
||||
|
||||
Please understand and answer the image based on this information. The image understanding is complete, so don't try to understand the image again.
|
||||
"""
|
||||
|
||||
|
||||
AUDIO_PROMPT = """
|
||||
provide a audio named {filename}. The description is: {description}.
|
||||
|
||||
Please understand and answer the audio based on this information. The audio understanding is complete, so don't try to understand the audio again.
|
||||
"""
|
||||
|
||||
VIDEO_PROMPT = """
|
||||
provide a video named {filename}. The description is: {description}.
|
||||
|
||||
Please understand and answer the video based on this information. The video understanding is complete, so don't try to understand the video again.
|
||||
"""
|
||||
|
||||
DATAFRAME_PROMPT = """
|
||||
provide a dataframe named {filename}. The description is: {description}.
|
||||
|
||||
You are able to use the dataframe to answer the question.
|
||||
You have to act like an data analyst who can do an effective analysis through dataframe.
|
||||
"""
|
33
prompts/input.py
Normal file
33
prompts/input.py
Normal file
@ -0,0 +1,33 @@
|
||||
AWESOMEGPT_PREFIX = """Awesome GPT is designed to be able to assist with a wide range of text, visual related tasks, data analysis related tasks, auditory related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics.
|
||||
Awesome GPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
||||
Awesome GPT is able to process and understand large amounts of various types of files(image, audio, video, dataframe, etc.). As a language model, Awesome GPT can not directly read various types of files(text, image, audio, video, dataframe, etc.), but it has a list of tools to finish different visual tasks.
|
||||
|
||||
Each image will have a file name formed as "image/xxx.png"
|
||||
Each audio will have a file name formed as "audio/xxx.mp3"
|
||||
Each video will have a file name formed as "video/xxx.mp4"
|
||||
Each dataframe will have a file name formed as "dataframe/xxx.csv"
|
||||
|
||||
Awesome GPT can invoke different tools to indirectly understand files(image, audio, video, dataframe, etc.). When talking about files(image, audio, video, dataframe, etc.), Awesome GPT is very strict to the file name and will never fabricate nonexistent files.
|
||||
When using tools to generate new files, Awesome GPT is also known that the file(image, audio, video, dataframe, etc.) may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real file.
|
||||
Awesome GPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the file content and file name. It will remember to provide the file name from the last tool observation, if a new file is generated.
|
||||
Human may provide new figures to Awesome GPT with a description. The description helps Awesome GPT to understand this file, but Awesome GPT should use tools to finish following tasks, rather than directly imagine from the description.
|
||||
|
||||
Overall, Awesome GPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics."""
|
||||
|
||||
AWESOMEGPT_SUFFIX = """TOOLS
|
||||
------
|
||||
Awesome GPT can ask the user to use tools to look up information that may be helpful in answering the users original question.
|
||||
You are very strict to the filename correctness and will never fake a file name if it does not exist.
|
||||
You will remember to provide the file name loyally if it's provided in the last tool observation.
|
||||
|
||||
The tools the human can use are:
|
||||
|
||||
{{tools}}
|
||||
|
||||
{format_instructions}
|
||||
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
|
||||
|
||||
{{{{input}}}}"""
|
@ -30,6 +30,10 @@ psycopg2-binary = "^2.9.5"
|
||||
wikipedia = "^1.4.0"
|
||||
google-search-results = "^2.4.2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
19
tools/base.py
Normal file
19
tools/base.py
Normal file
@ -0,0 +1,19 @@
|
||||
from langchain.agents.tools import Tool, BaseTool
|
||||
|
||||
|
||||
def tool(name, description):
|
||||
def decorator(func):
|
||||
func.name = name
|
||||
func.description = description
|
||||
func.is_tool = True
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BaseToolSet:
|
||||
def to_tools(cls) -> list[BaseTool]:
|
||||
method_tools = [getattr(cls, m) for m in dir(cls) if m.is_tool]
|
||||
return [
|
||||
Tool(name=m.name, description=m.description, func=m) for m in method_tools
|
||||
]
|
19
tools/cpu.py
19
tools/cpu.py
@ -1,4 +1,3 @@
|
||||
from utils import prompts
|
||||
from env import settings
|
||||
|
||||
import requests
|
||||
@ -12,8 +11,10 @@ from langchain.memory.chat_memory import BaseChatMemory
|
||||
import subprocess
|
||||
from typing import List, Union
|
||||
|
||||
from .base import tool, BaseToolSet
|
||||
|
||||
class Terminal:
|
||||
|
||||
class Terminal(BaseToolSet):
|
||||
"""Executes bash commands and returns the output."""
|
||||
|
||||
def __init__(self, strip_newlines: bool = False, return_err_output: bool = False):
|
||||
@ -21,7 +22,7 @@ class Terminal:
|
||||
self.strip_newlines = strip_newlines
|
||||
self.return_err_output = return_err_output
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Terminal",
|
||||
description="Executes commands in a terminal."
|
||||
"Input should be valid commands, "
|
||||
@ -49,8 +50,8 @@ class Terminal:
|
||||
return output
|
||||
|
||||
|
||||
class RequestsGet:
|
||||
@prompts(
|
||||
class RequestsGet(BaseToolSet):
|
||||
@tool(
|
||||
name="requests_get",
|
||||
description="A portal to the internet. "
|
||||
"Use this when you need to get specific content from a website."
|
||||
@ -66,7 +67,7 @@ class RequestsGet:
|
||||
return text
|
||||
|
||||
|
||||
class WineDB:
|
||||
class WineDB(BaseToolSet):
|
||||
def __init__(self):
|
||||
db = DatabaseReader(
|
||||
scheme="postgresql", # Database Scheme
|
||||
@ -87,7 +88,7 @@ class WineDB:
|
||||
documents = db.load_data(query=query)
|
||||
self.index = GPTSimpleVectorIndex(documents)
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Wine Recommendataion",
|
||||
description="A tool to recommend wines based on a user's input. "
|
||||
"Inputs are necessary factors for wine recommendations, such as the user's mood today, side dishes to eat with wine, people to drink wine with, what things you want to do, the scent and taste of their favorite wine."
|
||||
@ -108,11 +109,11 @@ class WineDB:
|
||||
return results.response + "\n\n" + wine
|
||||
|
||||
|
||||
class ExitConversation:
|
||||
class ExitConversation(BaseToolSet):
|
||||
def __init__(self, memory: BaseChatMemory):
|
||||
self.memory = memory
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="exit_conversation",
|
||||
description="A tool to exit the conversation. "
|
||||
"Use this when you want to end the conversation. "
|
||||
|
20
tools/factory.py
Normal file
20
tools/factory.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.agents import load_tools
|
||||
from langchain.agents.tools import BaseTool
|
||||
|
||||
from .base import BaseToolSet
|
||||
|
||||
|
||||
class ToolsFactory:
|
||||
@staticmethod
|
||||
def from_toolsets(toolsets: list[BaseToolSet]) -> list[BaseTool]:
|
||||
tools = []
|
||||
for toolset in toolsets:
|
||||
tools.extend(toolset.to_tools())
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]:
|
||||
return load_tools(toolnames, llm=llm)
|
53
tools/gpu.py
53
tools/gpu.py
@ -4,7 +4,7 @@ import uuid
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from utils import prompts, get_new_image_name
|
||||
from utils import get_new_image_name
|
||||
|
||||
from transformers import (
|
||||
CLIPSegProcessor,
|
||||
@ -23,8 +23,10 @@ from diffusers import (
|
||||
)
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
from .base import tool, BaseToolSet
|
||||
|
||||
class MaskFormer:
|
||||
|
||||
class MaskFormer(BaseToolSet):
|
||||
def __init__(self, device):
|
||||
print("Initializing MaskFormer to %s" % device)
|
||||
self.device = device
|
||||
@ -60,7 +62,7 @@ class MaskFormer:
|
||||
return image_mask.resize(original_image.size)
|
||||
|
||||
|
||||
class ImageEditing:
|
||||
class ImageEditing(BaseToolSet):
|
||||
def __init__(self, device):
|
||||
print("Initializing ImageEditing to %s" % device)
|
||||
self.device = device
|
||||
@ -73,7 +75,7 @@ class ImageEditing:
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(device)
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Remove Something From The Photo",
|
||||
description="useful when you want to remove and object or something from the photo "
|
||||
"from its description or location. "
|
||||
@ -84,7 +86,7 @@ class ImageEditing:
|
||||
image_path, to_be_removed_txt = inputs.split(",")
|
||||
return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Replace Something From The Photo",
|
||||
description="useful when you want to replace an object from the object description or "
|
||||
"location with another object from its description. "
|
||||
@ -113,7 +115,7 @@ class ImageEditing:
|
||||
return updated_image_path
|
||||
|
||||
|
||||
class InstructPix2Pix:
|
||||
class InstructPix2Pix(BaseToolSet):
|
||||
def __init__(self, device):
|
||||
print("Initializing InstructPix2Pix to %s" % device)
|
||||
self.device = device
|
||||
@ -127,7 +129,7 @@ class InstructPix2Pix:
|
||||
self.pipe.scheduler.config
|
||||
)
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Instruct Image Using Text",
|
||||
description="useful when you want to the style of the image to be like the text. "
|
||||
"like: make it look like a painting. or make it like a robot. "
|
||||
@ -151,7 +153,7 @@ class InstructPix2Pix:
|
||||
return updated_image_path
|
||||
|
||||
|
||||
class Text2Image:
|
||||
class Text2Image(BaseToolSet):
|
||||
def __init__(self, device):
|
||||
print("Initializing Text2Image to %s" % device)
|
||||
self.device = device
|
||||
@ -166,7 +168,7 @@ class Text2Image:
|
||||
"fewer digits, cropped, worst quality, low quality"
|
||||
)
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Generate Image From User Input Text",
|
||||
description="useful when you want to generate an image from a user input text and save it to a file. "
|
||||
"like: generate an image of an object or something, or generate an image that includes some objects. "
|
||||
@ -183,36 +185,7 @@ class Text2Image:
|
||||
return image_filename
|
||||
|
||||
|
||||
class ImageCaptioning:
|
||||
def __init__(self, device):
|
||||
print("Initializing ImageCaptioning to %s" % device)
|
||||
self.device = device
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.processor = BlipProcessor.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base"
|
||||
)
|
||||
self.model = BlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
|
||||
).to(self.device)
|
||||
|
||||
@prompts(
|
||||
name="Get Photo Description",
|
||||
description="useful when you want to know what is inside the photo. receives image_path as input. "
|
||||
"The input to this tool should be a string, representing the image_path. ",
|
||||
)
|
||||
def inference(self, image_path):
|
||||
inputs = self.processor(Image.open(image_path), return_tensors="pt").to(
|
||||
self.device, self.torch_dtype
|
||||
)
|
||||
out = self.model.generate(**inputs)
|
||||
captions = self.processor.decode(out[0], skip_special_tokens=True)
|
||||
print(
|
||||
f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}"
|
||||
)
|
||||
return captions
|
||||
|
||||
|
||||
class VisualQuestionAnswering:
|
||||
class VisualQuestionAnswering(BaseToolSet):
|
||||
def __init__(self, device):
|
||||
print("Initializing VisualQuestionAnswering to %s" % device)
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
@ -222,7 +195,7 @@ class VisualQuestionAnswering:
|
||||
"Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
|
||||
).to(self.device)
|
||||
|
||||
@prompts(
|
||||
@tool(
|
||||
name="Answer Question About The Image",
|
||||
description="useful when you need an answer for a question based on an image. "
|
||||
"like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
|
||||
|
65
utils.py
65
utils.py
@ -4,71 +4,6 @@ import torch
|
||||
import uuid
|
||||
import numpy as np
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
|
||||
|
||||
IMAGE_PROMPT = """
|
||||
{i}th file: provide a figure named {filename}. The description is: {description}.
|
||||
|
||||
Please understand and answer the image based on this information. The image understanding is complete, so don't try to understand the image again.
|
||||
"""
|
||||
|
||||
|
||||
AUDIO_PROMPT = """
|
||||
{i}th file: provide a audio named {filename}. The description is: {description}.
|
||||
|
||||
Please understand and answer the audio based on this information. The audio understanding is complete, so don't try to understand the audio again.
|
||||
"""
|
||||
|
||||
VIDEO_PROMPT = """
|
||||
{i}th file: provide a video named {filename}. The description is: {description}.
|
||||
|
||||
Please understand and answer the video based on this information. The video understanding is complete, so don't try to understand the video again.
|
||||
"""
|
||||
|
||||
DATAFRAME_PROMPT = """
|
||||
{i}th file: provide a dataframe named {filename}. The description is: {description}.
|
||||
|
||||
You are able to use the dataframe to answer the question.
|
||||
You have to act like an data analyst who can do an effective analysis through dataframe.
|
||||
"""
|
||||
|
||||
AWESOMEGPT_PREFIX = """Awesome GPT is designed to be able to assist with a wide range of text, visual related tasks, data analysis related tasks, auditory related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics.
|
||||
Awesome GPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
||||
Awesome GPT is able to process and understand large amounts of various types of files(image, audio, video, dataframe, etc.). As a language model, Awesome GPT can not directly read various types of files(text, image, audio, video, dataframe, etc.), but it has a list of tools to finish different visual tasks.
|
||||
|
||||
Each image will have a file name formed as "image/xxx.png"
|
||||
Each audio will have a file name formed as "audio/xxx.mp3"
|
||||
Each video will have a file name formed as "video/xxx.mp4"
|
||||
Each dataframe will have a file name formed as "dataframe/xxx.csv"
|
||||
|
||||
Awesome GPT can invoke different tools to indirectly understand files(image, audio, video, dataframe, etc.). When talking about files(image, audio, video, dataframe, etc.), Awesome GPT is very strict to the file name and will never fabricate nonexistent files.
|
||||
When using tools to generate new files, Awesome GPT is also known that the file(image, audio, video, dataframe, etc.) may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real file.
|
||||
Awesome GPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the file content and file name. It will remember to provide the file name from the last tool observation, if a new file is generated.
|
||||
Human may provide new figures to Awesome GPT with a description. The description helps Awesome GPT to understand this file, but Awesome GPT should use tools to finish following tasks, rather than directly imagine from the description.
|
||||
|
||||
Overall, Awesome GPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics."""
|
||||
|
||||
AWESOMEGPT_SUFFIX = """TOOLS
|
||||
------
|
||||
Awesome GPT can ask the user to use tools to look up information that may be helpful in answering the users original question.
|
||||
You are very strict to the filename correctness and will never fake a file name if it does not exist.
|
||||
You will remember to provide the file name loyally if it's provided in the last tool observation.
|
||||
|
||||
The tools the human can use are:
|
||||
|
||||
{{tools}}
|
||||
|
||||
{format_instructions}
|
||||
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
|
||||
|
||||
{{{{input}}}}"""
|
||||
|
||||
ERROR_PROMPT = "An error has occurred for the following text: \n{promptedQuery} Please explain this error.\n {e}"
|
||||
|
||||
|
||||
os.makedirs("image", exist_ok=True)
|
||||
os.makedirs("audio", exist_ok=True)
|
||||
|
Loading…
Reference in New Issue
Block a user