refactor: agent builder

pull/1/head
hanchchch 1 year ago
parent 83095ec2ce
commit 067379288b

@ -1,4 +1,4 @@
from typing import Tuple
from typing import Dict, Tuple
from llm import ChatOpenAI
from langchain.agents.agent import AgentExecutor
@ -9,62 +9,28 @@ from langchain.chat_models.base import BaseChatModel
from prompts.input import AWESOMEGPT_PREFIX, AWESOMEGPT_SUFFIX
from tools.base import BaseToolSet
from tools.factory import ToolsFactory
from tools.cpu import (
Terminal,
RequestsGet,
WineDB,
ExitConversation,
)
from tools.gpu import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
from handlers.base import FileHandler, FileType
from handlers.image import ImageCaptioning
from handlers.dataframe import CsvToDataframe
from handlers.base import BaseHandler, FileHandler, FileType
from env import settings
class AgentFactory:
class AgentBuilder:
def __init__(self):
self.llm: BaseChatModel = None
self.memory: BaseChatMemory = None
self.tools: list = None
self.handler: FileHandler = None
def create(self):
print("Initializing AwesomeGPT")
self.create_llm()
self.create_memory()
self.create_tools()
self.create_handler()
return initialize_agent(
self.tools,
self.llm,
agent="chat-conversational-react-description",
verbose=True,
memory=self.memory,
agent_kwargs={
"system_message": AWESOMEGPT_PREFIX,
"human_message": AWESOMEGPT_SUFFIX,
},
)
def create_llm(self):
def build_llm(self):
self.llm = ChatOpenAI(temperature=0)
def create_memory(self):
def build_memory(self):
self.memory = ConversationBufferMemory(
memory_key="chat_history", return_messages=True
)
def create_tools(self):
if self.memory is None:
raise ValueError("Memory must be initialized before tools")
def build_tools(self, toolsets: list[BaseToolSet] = []):
if self.llm is None:
raise ValueError("LLM must be initialized before tools")
@ -75,30 +41,35 @@ class AgentFactory:
if settings["BING_SEARCH_URL"] and settings["BING_SUBSCRIPTION_KEY"]:
toolnames.append("bing-search")
toolsets = [
Terminal(),
RequestsGet(),
ExitConversation(self.memory),
Text2Image("cuda"),
ImageEditing("cuda"),
InstructPix2Pix("cuda"),
VisualQuestionAnswering("cuda"),
]
if settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]:
toolsets.append(WineDB())
self.tools = [
*ToolsFactory.from_names(toolnames, llm=self.llm),
*ToolsFactory.from_toolsets(toolsets),
]
def create_handler(self):
self.handler = FileHandler(
{
FileType.IMAGE: ImageCaptioning("cuda"),
FileType.DATAFRAME: CsvToDataframe(),
}
def build_handler(self, handlers: Dict[FileType, BaseHandler]):
self.handler = FileHandler(handlers)
def get_agent(self):
print("Initializing AwesomeGPT")
if self.llm is None:
raise ValueError("LLM must be initialized before agent")
if self.tools is None:
raise ValueError("Tools must be initialized before agent")
if self.memory is None:
raise ValueError("Memory must be initialized before agent")
return initialize_agent(
self.tools,
self.llm,
agent="chat-conversational-react-description",
verbose=True,
memory=self.memory,
agent_kwargs={
"system_message": AWESOMEGPT_PREFIX,
"human_message": AWESOMEGPT_SUFFIX,
},
)
def get_handler(self):
@ -108,9 +79,16 @@ class AgentFactory:
return self.handler
@staticmethod
def get_agent_and_handler() -> Tuple[AgentExecutor, FileHandler]:
factory = AgentFactory()
agent = factory.create()
handler = factory.get_handler()
def get_agent_and_handler(
toolsets: list[BaseToolSet], handlers: Dict[FileType, BaseHandler]
) -> Tuple[AgentExecutor, FileHandler]:
builder = AgentBuilder()
builder.build_llm()
builder.build_memory()
builder.build_tools(toolsets)
builder.build_handler(handlers)
agent = builder.get_agent()
handler = builder.get_handler()
return (agent, handler)

@ -1,16 +1,53 @@
from typing import List, TypedDict
from typing import Dict, List, TypedDict
import re
from fastapi import FastAPI
from pydantic import BaseModel
from s3 import upload
from env import settings
from prompts.error import ERROR_PROMPT
from agent import AgentFactory
from agent import AgentBuilder
from tools.base import BaseToolSet
from tools.cpu import (
Terminal,
RequestsGet,
WineDB,
ExitConversation,
)
from tools.gpu import (
ImageEditing,
InstructPix2Pix,
Text2Image,
VisualQuestionAnswering,
)
from handlers.base import BaseHandler, FileType
from handlers.image import ImageCaptioning
from handlers.dataframe import CsvToDataframe
app = FastAPI()
agent, handler = AgentFactory.get_agent_and_handler()
toolsets: List[BaseToolSet] = [
Terminal(),
RequestsGet(),
ExitConversation(),
Text2Image("cuda"),
ImageEditing("cuda"),
InstructPix2Pix("cuda"),
VisualQuestionAnswering("cuda"),
]
handlers: Dict[FileType, BaseHandler] = {
FileType.IMAGE: ImageCaptioning("cuda"),
FileType.DATAFRAME: CsvToDataframe(),
}
if settings["WINEDB_HOST"] and settings["WINEDB_PASSWORD"]:
toolsets.append(WineDB())
agent, handler = AgentBuilder.get_agent_and_handler(
toolsets=toolsets, handlers=handlers
)
class Request(BaseModel):

@ -110,18 +110,15 @@ class WineDB(BaseToolSet):
class ExitConversation(BaseToolSet):
def __init__(self, memory: BaseChatMemory):
self.memory = memory
@tool(
name="exit_conversation",
description="A tool to exit the conversation. "
"Use this when you want to end the conversation. "
"Input should be a user's query."
"Input should be a user's query and user's session."
"The output will be a message that the conversation is over.",
)
def inference(self, query: str) -> str:
def inference(self, query: str, session: str) -> str:
"""Run the tool."""
self.memory.clear()
# session.clear() # TODO
return f"My original question was: {query}"

Loading…
Cancel
Save