feat: execution tracing callback

pull/31/head
hanchchch 1 year ago
parent 4c075290d6
commit 94db570acf

@@ -57,7 +57,7 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse:
     files = request.files
     session = request.session
-    executor = agent_manager.get_or_create_executor(session)
+    executor = agent_manager.create_executor(session)
 
     promptedQuery = "\n".join([file_handler.handle(file) for file in files])
     promptedQuery += query
@@ -67,7 +67,7 @@ async def execute(request: ExecuteRequest) -> ExecuteResponse:
     except Exception as e:
         return {"answer": str(e), "files": []}
 
-    files = re.findall(r"\[file/\S*\]", res["output"])
+    files = re.findall(r"\[file://\S*\]", res["output"])
     files = [file[1:-1] for file in files]
 
     return {
@@ -96,7 +96,7 @@ async def execute_async(execution_id: str):
     result = {}
     if execution.status == "SUCCESS" and execution.result:
         output = execution.result.get("output", "")
-        files = re.findall(r"\[file/\S*\]", output)
+        files = re.findall(r"\[file://\S*\]", output)
         files = [file[1:-1] for file in files]
         result = {
             "answer": output,
@@ -105,6 +105,7 @@ async def execute_async(execution_id: str):
     return {
         "status": execution.status,
        "info": execution.info,
+        "result": result,
     }
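For reference, both routes above now scan the agent's output for the new [file://...] link format; a minimal, self-contained sketch of that extraction (the sample output string is invented for illustration):

import re

# A model reply that embeds two links in the new [file://...] format.
output = "Here is the chart [file://file/chart.png] and the raw data [file://file/data.csv]"

files = re.findall(r"\[file://\S*\]", output)
files = [file[1:-1] for file in files]  # strip the surrounding brackets
print(files)  # ['file://file/chart.png', 'file://file/data.csv']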

@@ -8,6 +8,7 @@ celery_app = Celery(__name__)
 celery_app.conf.broker_url = settings["CELERY_BROKER_URL"]
 celery_app.conf.result_backend = settings["CELERY_BROKER_URL"]
 celery_app.conf.update(
+    task_track_started=True,
     task_serializer="json",
     accept_content=["json"],  # Ignore other content
     result_serializer="json",
@@ -15,10 +16,11 @@ celery_app.conf.update(
 )
 
 
-@celery_app.task(name="task_execute")
-def task_execute(session: str, prompt: str):
-    executor = agent_manager.get_or_create_executor(session)
+@celery_app.task(name="task_execute", bind=True)
+def task_execute(self, session: str, prompt: str):
+    executor = agent_manager.create_executor(session, self)
     response = executor({"input": prompt})
 
     return {"output": response["output"]}

@@ -4,6 +4,7 @@ from core.tools.factory import ToolsFactory
 from env import settings
 from langchain.chat_models.base import BaseChatModel
 from langchain.output_parsers.base import BaseOutputParser
+from langchain.callbacks.base import BaseCallbackManager
 
 from .chat_agent import ConversationalChatAgent
 from .llm import ChatOpenAI
@@ -17,8 +18,10 @@ class AgentBuilder:
         self.global_tools: list = None
         self.toolsets = toolsets
 
-    def build_llm(self):
-        self.llm = ChatOpenAI(temperature=0)
+    def build_llm(self, callback_manager: BaseCallbackManager = None):
+        self.llm = ChatOpenAI(
+            temperature=0, callback_manager=callback_manager, verbose=True
+        )
         self.llm.check_access()
 
     def build_parser(self):
@@ -40,6 +43,12 @@ class AgentBuilder:
             *ToolsFactory.create_global_tools(self.toolsets),
         ]
 
+    def get_parser(self):
+        if self.parser is None:
+            raise ValueError("Parser is not initialized yet")
+
+        return self.parser
+
     def get_global_tools(self):
         if self.global_tools is None:
             raise ValueError("Global tools are not initialized yet")

@@ -2,22 +2,45 @@ from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult
+from celery import Task
 
 from ansi import ANSI, Color, Style, dim_multiline
 from logger import logger
 
 
 class EVALCallbackHandler(BaseCallbackHandler):
     @property
     def ignore_llm(self) -> bool:
         return False
 
+    def set_parser(self, parser) -> None:
+        self.parser = parser
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         pass
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
+        text = response.generations[0][0].text
+        parsed = self.parser.parse_all(text)
+        logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + parsed["plan"])
+        logger.info(ANSI("What I Did").to(Color.blue()) + ": " + parsed["what_i_did"])
+        logger.info(
+            ANSI("Action").to(Color.cyan())
+            + ": "
+            + ANSI(parsed["action"]).to(Style.bold())
+        )
+        logger.info(
+            ANSI("Input").to(Color.cyan())
+            + ": "
+            + dim_multiline(parsed["action_input"])
+        )
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        pass
+        logger.info(ANSI(f"on_llm_new_token {token}").to(Color.green(), Style.italic()))
 
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -85,3 +108,87 @@ class EVALCallbackHandler(BaseCallbackHandler):
             + ": "
             + dim_multiline(finish.return_values.get("output", ""))
         )
+
+
+class ExecutionTracingCallbackHandler(BaseCallbackHandler):
+    def __init__(self, execution: Task):
+        self.execution = execution
+
+    def set_parser(self, parser) -> None:
+        self.parser = parser
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        text = response.generations[0][0].text
+        parsed = self.parser.parse_all(text)
+        self.execution.update_state(state="LLM_END", meta=parsed)
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        pass
+
+    def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        pass
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        pass
+
+    def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        self.execution.update_state(state="CHAIN_ERROR", meta={"error": str(error)})
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        input_str: str,
+        **kwargs: Any,
+    ) -> None:
+        pass
+
+    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+        pass
+
+    def on_tool_end(
+        self,
+        output: str,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        previous = self.execution.AsyncResult(self.execution.request.id)
+        self.execution.update_state(
+            state="TOOL_END", meta={**previous.info, "observation": output}
+        )
+
+    def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        self.execution.update_state(state="TOOL_ERROR", meta={"error": str(error)})
+
+    def on_text(
+        self,
+        text: str,
+        color: Optional[str] = None,
+        end: str = "",
+        **kwargs: Optional[str],
+    ) -> None:
+        pass
+
+    def on_agent_finish(
+        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
+    ) -> None:
+        self.execution.update_state(
+            state="AGENT_FINISH",
+            meta={"output": finish.return_values.get("output", "")},
+        )
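How the tracing handler reports progress, in isolation: it only forwards parsed events to Task.update_state, so a stand-in object that records those calls is enough to see the payloads. The _RecordingTask class below is illustrative only and not part of the repo:

from langchain.schema import AgentFinish

from core.agents.callback import ExecutionTracingCallbackHandler


class _RecordingTask:  # stand-in for celery.Task, used only for this illustration
    def __init__(self):
        self.states = []

    def update_state(self, state, meta):
        self.states.append((state, meta))


task = _RecordingTask()
handler = ExecutionTracingCallbackHandler(task)

# AGENT_FINISH carries the final answer; LLM_END would carry the dict from the
# parser's parse_all() (action, plan, what_i_did, action_input).
handler.on_agent_finish(AgentFinish(return_values={"output": "done"}, log=""))
print(task.states)  # [('AGENT_FINISH', {'output': 'done'})]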

@@ -1,9 +1,9 @@
-from typing import Dict
+from typing import Dict, Optional
 
-from langchain.agents.agent import Agent, AgentExecutor
-from langchain.agents.tools import BaseTool
-from langchain.callbacks import set_handler
+from celery import Task
+from langchain.agents.agent import AgentExecutor
 from langchain.callbacks.base import CallbackManager
+from langchain.callbacks import set_handler
 from langchain.chains.conversation.memory import ConversationBufferMemory
 from langchain.memory.chat_memory import BaseChatMemory
@@ -11,68 +11,73 @@ from core.tools.base import BaseToolSet
 from core.tools.factory import ToolsFactory
 
 from .builder import AgentBuilder
-from .callback import EVALCallbackHandler
+from .callback import EVALCallbackHandler, ExecutionTracingCallbackHandler
 
-callback_manager = CallbackManager([EVALCallbackHandler()])
+set_handler(EVALCallbackHandler())
 
 
 class AgentManager:
     def __init__(
         self,
-        agent: Agent,
-        global_tools: list[BaseTool],
         toolsets: list[BaseToolSet] = [],
     ):
-        self.agent: Agent = agent
-        self.global_tools: list[BaseTool] = global_tools
         self.toolsets: list[BaseToolSet] = toolsets
+        self.memories: Dict[str, BaseChatMemory] = {}
         self.executors: Dict[str, AgentExecutor] = {}
 
     def create_memory(self) -> BaseChatMemory:
         return ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
-    def create_executor(self, session: str) -> AgentExecutor:
-        memory: BaseChatMemory = self.create_memory()
+    def get_or_create_memory(self, session: str) -> BaseChatMemory:
+        if not (session in self.memories):
+            self.memories[session] = self.create_memory()
+        return self.memories[session]
+
+    def create_executor(
+        self, session: str, execution: Optional[Task] = None
+    ) -> AgentExecutor:
+        builder = AgentBuilder(self.toolsets)
+        builder.build_parser()
+
+        callbacks = []
+        eval_callback = EVALCallbackHandler()
+        eval_callback.set_parser(builder.get_parser())
+        callbacks.append(eval_callback)
+
+        if execution:
+            execution_callback = ExecutionTracingCallbackHandler(execution)
+            execution_callback.set_parser(builder.get_parser())
+            callbacks.append(execution_callback)
+
+        callback_manager = CallbackManager(callbacks)
+
+        builder.build_llm(callback_manager)
+        builder.build_global_tools()
+
+        memory: BaseChatMemory = self.get_or_create_memory(session)
         tools = [
-            *self.global_tools,
+            *builder.get_global_tools(),
             *ToolsFactory.create_per_session_tools(
                 self.toolsets,
                 get_session=lambda: (session, self.executors[session]),
             ),
         ]
 
         for tool in tools:
-            tool.set_callback_manager(callback_manager)
+            tool.callback_manager = callback_manager
 
-        return AgentExecutor.from_agent_and_tools(
-            agent=self.agent,
+        executor = AgentExecutor.from_agent_and_tools(
+            agent=builder.get_agent(),
             tools=tools,
             memory=memory,
             callback_manager=callback_manager,
             verbose=True,
         )
-
-    def remove_executor(self, session: str) -> None:
-        if session in self.executors:
-            del self.executors[session]
-
-    def get_or_create_executor(self, session: str) -> AgentExecutor:
-        if not (session in self.executors):
-            self.executors[session] = self.create_executor(session=session)
-        return self.executors[session]
+        self.executors[session] = executor
+        return executor
 
     @staticmethod
     def create(toolsets: list[BaseToolSet]) -> "AgentManager":
-        builder = AgentBuilder(toolsets)
-        builder.build_llm()
-        builder.build_parser()
-        builder.build_global_tools()
-
-        agent = builder.get_agent()
-        global_tools = builder.get_global_tools()
-
         return AgentManager(
-            agent=agent,
-            global_tools=global_tools,
             toolsets=toolsets,
         )
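Taken together, the manager is now driven roughly like this (a sketch; assumes the repo's toolsets and OpenAI credentials are configured, and that the core.agents.manager path matches the relative imports above):

from core.agents.manager import AgentManager

agent_manager = AgentManager.create(toolsets=[])

# Synchronous route: each call builds a fresh executor, but get_or_create_memory
# keeps one ConversationBufferMemory per session, so chat history survives.
executor = agent_manager.create_executor("my-session")

# Async route (inside the bound Celery task): passing the Task adds the
# ExecutionTracingCallbackHandler next to the EVAL logging callback.
# executor = agent_manager.create_executor(session, self)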

@@ -1,16 +1,31 @@
 import re
-import time
 from typing import Dict
 
 from langchain.output_parsers.base import BaseOutputParser
 
-from ansi import ANSI, Color, Style
-from core.agents.callback import dim_multiline
 from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS
-from logger import logger
 
 
 class EvalOutputParser(BaseOutputParser):
+    @staticmethod
+    def parse_all(text: str) -> Dict[str, str]:
+        regex = r"Action: (.*?)[\n]Plan:(.*)[\n]What I Did:(.*)[\n]Action Input: (.*)"
+        match = re.search(regex, text, re.DOTALL)
+        if not match:
+            raise Exception("parse error")
+
+        action = match.group(1).strip()
+        plan = match.group(2)
+        what_i_did = match.group(3)
+        action_input = match.group(4).strip(" ").strip('"')
+
+        return {
+            "action": action,
+            "plan": plan,
+            "what_i_did": what_i_did,
+            "action_input": action_input,
+        }
+
     def get_format_instructions(self) -> str:
         return EVAL_FORMAT_INSTRUCTIONS
@@ -20,21 +35,9 @@ class EvalOutputParser(BaseOutputParser):
         if not match:
             raise Exception("parse error")
 
-        action = match.group(1).strip()
-        plan = match.group(2)
-        what_i_did = match.group(3)
-        action_input = match.group(4)
-
-        logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + plan)
-        logger.info(ANSI("What I Did").to(Color.blue()) + ": " + what_i_did)
-        time.sleep(1)
-        logger.info(
-            ANSI("Action").to(Color.cyan()) + ": " + ANSI(action).to(Style.bold())
-        )
-        time.sleep(1)
-        logger.info(ANSI("Input").to(Color.cyan()) + ": " + dim_multiline(action_input))
-        time.sleep(1)
-
-        return {"action": action, "action_input": action_input.strip(" ").strip('"')}
+        parsed = EvalOutputParser.parse_all(text)
+        return {"action": parsed["action"], "action_input": parsed["action_input"]}
 
     def __str__(self):
         return "EvalOutputParser"

@@ -3,7 +3,7 @@ EVAL_PREFIX = """{bot_name} can execute any user's request.
 {bot_name} has permission to handle one instance and can handle the environment in it at will.
 You can code, run, debug, and test yourself. You can correct the code appropriately by looking at the error message.
-I can understand, process, and create various types of files. Every files except the code must be restored in file/ directory.
+I can understand, process, and create various types of files.
 {bot_name} can do whatever it takes to execute the user's request. Let's think step by step.
 """
@@ -37,7 +37,7 @@ EVAL_SUFFIX = """TOOLS
 {bot_name} can ask the user to use tools to look up information that may be helpful in answering the users original question.
 You are very strict to the filename correctness and will never fake a file name if it does not exist.
 You will remember to provide the file name loyally if it's provided in the last tool observation.
-If you have to include files in your response, you must move the files into file/ directory and provide the filename in [file/FILENAME] format.
+If you have to include files in your response, you must provide the filepath in [file://filepath] format.
 The tools the human can use are:
