diff --git a/ansi.py b/ansi.py new file mode 100644 index 0000000..d52f0d0 --- /dev/null +++ b/ansi.py @@ -0,0 +1,105 @@ +class Code: + def __init__(self, value: int): + self.value = value + + def __str__(self): + return "%d" % self.value + + +class Color(Code): + def bg(self) -> "Color": + self.value += 10 + return self + + def bright(self) -> "Color": + self.value += 60 + return self + + @staticmethod + def black() -> "Color": + return Color(30) + + @staticmethod + def red() -> "Color": + return Color(31) + + @staticmethod + def green() -> "Color": + return Color(32) + + @staticmethod + def yellow() -> "Color": + return Color(33) + + @staticmethod + def blue() -> "Color": + return Color(34) + + @staticmethod + def magenta() -> "Color": + return Color(35) + + @staticmethod + def cyan() -> "Color": + return Color(36) + + @staticmethod + def white() -> "Color": + return Color(37) + + @staticmethod + def default() -> "Color": + return Color(39) + + +class Style(Code): + @staticmethod + def reset() -> "Style": + return Style(0) + + @staticmethod + def bold() -> "Style": + return Style(1) + + @staticmethod + def dim() -> "Style": + return Style(2) + + @staticmethod + def italic() -> "Style": + return Style(3) + + @staticmethod + def underline() -> "Style": + return Style(4) + + @staticmethod + def blink() -> "Style": + return Style(5) + + @staticmethod + def reverse() -> "Style": + return Style(7) + + @staticmethod + def conceal() -> "Style": + return Style(8) + + +class ANSI: + ESCAPE = "\x1b[" + CLOSE = "m" + + def __init__(self, text: str): + self.text = text + self.args = [] + + def join(self) -> str: + return ANSI.ESCAPE + ";".join([str(a) for a in self.args]) + ANSI.CLOSE + + def wrap(self, text: str) -> str: + return self.join() + text + ANSI(Style.reset()).join() + + def to(self, *args: str): + self.args = list(args) + return self.wrap(self.text) diff --git a/api/main.py b/api/main.py index 4234cc9..92b6305 100644 --- a/api/main.py +++ b/api/main.py @@ -91,15 +91,10 @@ async def command(request: Request) -> Response: files = request.files session = request.key - logger.info("=============== Running =============") - logger.info(f"Query: {query}, Files: {files}") executor = agent_manager.get_or_create_executor(session) - logger.info(f"======> Previous memory:\n\t{executor.memory}") - promptedQuery = "\n".join([file_handler.handle(file) for file in files]) promptedQuery += query - logger.info(f"======> Prompted Text:\n\t{promptedQuery}") try: res = executor({"input": promptedQuery}) diff --git a/core/agents/callback.py b/core/agents/callback.py new file mode 100644 index 0000000..387ced5 --- /dev/null +++ b/core/agents/callback.py @@ -0,0 +1,97 @@ +from typing import Any, Dict, List, Optional, Union + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult + +from logger import logger +from ansi import ANSI, Color, Style + + +class EVALCallbackHandler(BaseCallbackHandler): + def dim_multiline(self, message: str) -> str: + return message.split("\n")[0] + ANSI( + "\n... ".join(["", *message.split("\n")[1:]]) + ).to(Color.black().bright()) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + pass + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + logger.info(ANSI(f"Entering new chain.").to(Color.green(), Style.italic())) + logger.info(ANSI("Prompted Text").to(Color.yellow()) + f': {inputs["input"]}') + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + logger.info(ANSI(f"Finished chain.").to(Color.green(), Style.italic())) + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logger.error(ANSI(f"Chain Error").to(Color.red()) + f": {error}") + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + pass + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + logger.info( + ANSI("Action").to(Color.cyan()) + ": " + ANSI(action.tool).to(Style.bold()) + ) + logger.info( + ANSI("Input").to(Color.cyan()) + + ": " + + self.dim_multiline(action.tool_input) + ) + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + logger.info( + ANSI("Observation").to(Color.magenta()) + ": " + self.dim_multiline(output) + ) + logger.info(ANSI("Thinking...").to(Color.green(), Style.italic())) + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logger.error(ANSI("Tool Error").to(Color.red()) + f": {error}") + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], + ) -> None: + pass + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + logger.info( + ANSI("Final Answer").to(Color.yellow()) + + ": " + + self.dim_multiline(finish.return_values.get("output", "")) + ) diff --git a/core/agents/manager.py b/core/agents/manager.py index ee653e8..52d500d 100644 --- a/core/agents/manager.py +++ b/core/agents/manager.py @@ -4,13 +4,20 @@ from langchain.agents.tools import BaseTool from langchain.agents.agent import Agent, AgentExecutor from langchain.chains.conversation.memory import ConversationBufferMemory from langchain.memory.chat_memory import BaseChatMemory +from langchain.callbacks.base import CallbackManager +from langchain.callbacks import set_handler from core.tools.base import BaseToolSet from core.tools.factory import ToolsFactory +from .callback import EVALCallbackHandler from .builder import AgentBuilder +callback_manager = CallbackManager([EVALCallbackHandler()]) +set_handler(EVALCallbackHandler()) + + class AgentManager: def __init__( self, @@ -28,16 +35,21 @@ class AgentManager: def create_executor(self, session: str) -> AgentExecutor: memory: BaseChatMemory = self.create_memory() + tools = [ + *self.global_tools, + *ToolsFactory.create_per_session_tools( + self.toolsets, + get_session=lambda: (session, self.executors[session]), + ), + ] + for tool in tools: + tool.set_callback_manager(callback_manager) + return AgentExecutor.from_agent_and_tools( agent=self.agent, - tools=[ - *self.global_tools, - *ToolsFactory.create_per_session_tools( - self.toolsets, - get_session=lambda: (session, self.executors[session]), - ), - ], + tools=tools, memory=memory, + callback_manager=callback_manager, verbose=True, ) diff --git a/logger.py b/logger.py index 01aefd3..f080525 100644 --- a/logger.py +++ b/logger.py @@ -1,7 +1,12 @@ import logging from env import settings -logger = logging.getLogger("EVAL") +logger = logging.getLogger() +formatter = logging.Formatter("%(message)s") +ch = logging.StreamHandler() +ch.setFormatter(formatter) +logger.addHandler(ch) + if settings["LOG_LEVEL"] == "DEBUG": logger.setLevel(logging.DEBUG) else: