def dim_multiline(message: str) -> str:
    """Return ``message`` with its continuation lines restyled.

    The first line is returned untouched. Every following line is prefixed
    with a newline and ``"... "``, and that whole continuation is wrapped
    with ``ANSI(...).to(Color.black().bright())``. A message without any
    newline is returned as-is.
    """
    head, *tail = message.split("\n")
    if not tail:
        return head
    # Same text as "\n... ".join([""] + tail): each extra line gets the prefix.
    continuation = "".join("\n... " + line for line in tail)
    return head + ANSI(continuation).to(Color.black().bright())
re.findall("(image/\S*png)|(dataframe/\S*csv)", res["output"]) return { "response": res["output"], - "files": [uploader.upload(image) for image in images] - + [uploader.upload(dataframe) for dataframe in dataframes], + "files": [uploader.upload(file) for file in files], } diff --git a/core/agents/builder.py b/core/agents/builder.py index 14f43f2..cf7c5fc 100644 --- a/core/agents/builder.py +++ b/core/agents/builder.py @@ -1,10 +1,9 @@ -from langchain.chat_models.base import BaseChatModel -from langchain.output_parsers.base import BaseOutputParser - from core.prompts.input import EVAL_PREFIX, EVAL_SUFFIX from core.tools.base import BaseToolSet from core.tools.factory import ToolsFactory from env import settings +from langchain.chat_models.base import BaseChatModel +from langchain.output_parsers.base import BaseOutputParser from .chat_agent import ConversationalChatAgent from .llm import ChatOpenAI diff --git a/core/agents/callback.py b/core/agents/callback.py index c18d1f0..6769a14 100644 --- a/core/agents/callback.py +++ b/core/agents/callback.py @@ -1,15 +1,10 @@ from typing import Any, Dict, List, Optional, Union -from ansi import ANSI, Color, Style from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult -from logger import logger - -def dim_multiline(message: str) -> str: - return message.split("\n")[0] + ANSI( - "\n... 
".join(["", *message.split("\n")[1:]]) - ).to(Color.black().bright()) +from ansi import ANSI, Color, Style, dim_multiline +from logger import logger class EVALCallbackHandler(BaseCallbackHandler): @@ -41,7 +36,9 @@ class EVALCallbackHandler(BaseCallbackHandler): def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: - logger.error(ANSI(f"Chain Error").to(Color.red()) + f": {error}") + logger.error( + ANSI(f"Chain Error").to(Color.red()) + ": " + dim_multiline(str(error)) + ) def on_tool_start( self, diff --git a/core/agents/llm.py b/core/agents/llm.py index 27ab2f5..040c1a0 100644 --- a/core/agents/llm.py +++ b/core/agents/llm.py @@ -16,6 +16,7 @@ from langchain.schema import ( SystemMessage, ) from langchain.utils import get_from_dict_or_env +from logger import logger from pydantic import BaseModel, Extra, Field, root_validator from tenacity import ( before_sleep_log, @@ -25,8 +26,6 @@ from tenacity import ( wait_exponential, ) -from logger import logger - def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]: import openai diff --git a/core/agents/parser.py b/core/agents/parser.py index e338060..b0f2714 100644 --- a/core/agents/parser.py +++ b/core/agents/parser.py @@ -2,10 +2,11 @@ import re import time from typing import Dict +from langchain.output_parsers.base import BaseOutputParser + from ansi import ANSI, Color, Style from core.agents.callback import dim_multiline from core.prompts.input import EVAL_FORMAT_INSTRUCTIONS -from langchain.output_parsers.base import BaseOutputParser from logger import logger @@ -24,8 +25,8 @@ class EvalOutputParser(BaseOutputParser): what_i_did = match.group(3) action_input = match.group(4) - logger.info(ANSI("Plan").to(Color.blue()) + ": " + plan) - logger.info(ANSI("What I Did").to(Color.blue().bright()) + ": " + what_i_did) + logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + plan) + logger.info(ANSI("What I Did").to(Color.blue()) + ": " + what_i_did) 
time.sleep(1) logger.info( ANSI("Action").to(Color.cyan()) + ": " + ANSI(action).to(Style.bold()) diff --git a/core/prompts/input.py b/core/prompts/input.py index ccd54b6..3c2aa33 100644 --- a/core/prompts/input.py +++ b/core/prompts/input.py @@ -28,8 +28,8 @@ You should replace sensitive data or encrypted data with "d1dy0uth1nk7hat1t1s7ha Your response should be in the following schema: Action: Final Answer -Plan: None -What I Did: None +Plan: ... +What I Did: ... Action Input: string \\ You should put what you want to return to use here. """ diff --git a/core/tools/editor/__init__.py b/core/tools/editor/__init__.py index 0ab839d..c572ff1 100644 --- a/core/tools/editor/__init__.py +++ b/core/tools/editor/__init__.py @@ -106,6 +106,12 @@ class CodeEditor(BaseToolSet): ) + "Each patch has to be formatted like below.\n" "|,|,|" + "Here is an example. If the original code is:\n" + "print('hello world')\n" + "and you want to change it to:\n" + "print('hi corca')\n" + "then the patch should be:\n" + "test.py|1,8|1,19|hi corca\n" "Code between start and end will be replaced with new_code. " "The output will be written/deleted bytes or error message. 
", ) diff --git a/core/tools/editor/patch.py b/core/tools/editor/patch.py index 2753dec..a25eb74 100644 --- a/core/tools/editor/patch.py +++ b/core/tools/editor/patch.py @@ -57,6 +57,7 @@ test.py|11,16|11,16|_titles """ import os +import re from pathlib import Path from typing import Tuple @@ -70,10 +71,13 @@ class Position: self.line: int = line self.col: int = col + def __str__(self): + return f"(Ln {self.line}, Col {self.col})" + @staticmethod def from_str(pos: str) -> "Position": line, col = pos.split(Position.separator) - return Position(int(line), int(col)) + return Position(int(line) - 1, int(col) - 1) class PatchCommand: @@ -122,10 +126,22 @@ class PatchCommand: @staticmethod def from_str(command: str) -> "PatchCommand": - filepath, start, end = command.split(PatchCommand.separator)[:3] - content = command[len(filepath + start + end) + 3 :] + match = re.search( + r"(.*)\|([0-9])*,([0-9])*\|([0-9]*),([0-9]*)(\||\n)(.*)", + command, + re.DOTALL, + ) + filepath = match.group(1) + start_line = match.group(2) + start_col = match.group(3) + end_line = match.group(4) + end_col = match.group(5) + content = match.group(7) return PatchCommand( - filepath, Position.from_str(start), Position.from_str(end), content + filepath, + Position.from_str(f"{start_line},{start_col}"), + Position.from_str(f"{end_line},{end_col}"), + content, ) diff --git a/core/tools/terminal/__init__.py b/core/tools/terminal/__init__.py index 59d5b6d..b27104f 100644 --- a/core/tools/terminal/__init__.py +++ b/core/tools/terminal/__init__.py @@ -1,8 +1,12 @@ +import os import subprocess -from tempfile import TemporaryFile +import time +from datetime import datetime from typing import Dict, List +from ansi import ANSI, Color, Style from core.tools.base import BaseToolSet, SessionGetter, ToolScope, tool +from core.tools.terminal.stdout import StdoutTracer from core.tools.terminal.syscall import SyscallTracer from env import settings from logger import logger @@ -24,22 +28,23 @@ class 
class StdoutTracer:
    """Poll a child process's stdout/stderr pipes and collect what it prints.

    The pipes are switched to non-blocking mode and read every ``interval``
    seconds until the process exits, or until nothing has been read for more
    than ``timeout`` seconds, in which case the process is killed.

    ``process`` must expose readable ``stdout`` and ``stderr`` pipes, e.g. by
    being created with ``subprocess.PIPE`` for both.
    """

    def __init__(
        self,
        process: subprocess.Popen,
        timeout: int = 30,
        interval: float = 0.1,
        on_output: Callable[["PipeType", str], None] = lambda pipe, text: None,
    ):
        """Store the process and polling settings.

        Args:
            process: the running child whose pipes are read.
            timeout: seconds without any output after which the child is killed.
            interval: seconds to sleep between two polls.
            on_output: called as ``on_output(pipe, text)`` for every decoded
                chunk; the default accepts both arguments and ignores them.
        """
        # Local import: nothing else in this module needs codecs.
        import codecs

        self.process: subprocess.Popen = process
        self.timeout: int = timeout
        self.interval: float = interval
        # Time of the most recent read that returned data (set when waiting starts).
        self.last_output: Optional[datetime] = None
        self.on_output: Callable[["PipeType", str], None] = on_output
        # One incremental decoder per pipe: a multi-byte character split across
        # two reads is reassembled instead of raising UnicodeDecodeError, and
        # invalid bytes become U+FFFD rather than aborting the whole capture.
        make_decoder = codecs.getincrementaldecoder("utf-8")
        self._decoders = {
            "stdout": make_decoder(errors="replace"),
            "stderr": make_decoder(errors="replace"),
        }

    def nonblock(self):
        """Put both pipes of the process into non-blocking mode."""
        os.set_blocking(self.process.stdout.fileno(), False)
        os.set_blocking(self.process.stderr.fileno(), False)

    def get_output(self, pipe: "PipeType", final: bool = False) -> str:
        """Read whatever is currently available on ``pipe`` and decode it.

        Args:
            pipe: ``"stdout"`` or ``"stderr"``; any other value yields ``""``.
            final: pass True on the last read so bytes of an unfinished
                character still held by the decoder are flushed.

        Returns:
            The newly decoded text, or ``""`` when nothing was available.
        """
        if pipe == "stdout":
            stream = self.process.stdout
        elif pipe == "stderr":
            stream = self.process.stderr
        else:
            return ""

        # A non-blocking read gives None when no data is ready.
        raw = stream.read()
        if raw:
            self.last_output = datetime.now()

        decoded = self._decoders[pipe].decode(raw or b"", final)
        if decoded:
            self.on_output(pipe, decoded)
        return decoded

    def last_output_passed(self, seconds: int) -> bool:
        """Return True when more than ``seconds`` elapsed since the last output.

        Returns False if no reference time has been recorded yet.
        """
        if self.last_output is None:
            return False
        # total_seconds() covers the whole interval; ``.seconds`` is only the
        # sub-day component and wraps back to 0 after 24 hours.
        return (datetime.now() - self.last_output).total_seconds() > seconds

    def wait_until_stop_or_exit(self) -> Tuple[Optional[int], str]:
        """Collect output until the process exits or stays silent too long.

        Returns:
            ``(exitcode, output)``. ``exitcode`` is the process's return code,
            or ``None`` when the process was killed after ``timeout`` seconds
            without output. ``output`` is the stdout and stderr text in the
            order the chunks were read.
        """
        self.nonblock()
        self.last_output = datetime.now()
        chunks = []
        exitcode = None
        while True:
            chunks.append(self.get_output("stdout"))
            chunks.append(self.get_output("stderr"))

            exitcode = self.process.poll()
            if exitcode is not None:
                break

            if self.last_output_passed(self.timeout):
                self.process.kill()
                # Reap the killed child so it does not linger as a zombie;
                # exitcode deliberately stays None to signal the timeout.
                self.process.wait()
                break

            time.sleep(self.interval)

        # The child may have written between the last read and its exit:
        # drain both pipes once more so that trailing output is not lost.
        chunks.append(self.get_output("stdout", final=True))
        chunks.append(self.get_output("stderr", final=True))

        return (exitcode, "".join(chunks))