refactor: BaseTracer

This commit is contained in:
hanchchch 2023-04-06 01:33:49 +09:00
parent 82f8126687
commit 4a3929b08a
5 changed files with 59 additions and 35 deletions

View File

@ -6,18 +6,15 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from ansi import ANSI, Color, Style, dim_multiline
from core.agents.manager import AgentManager
from core.handlers.base import BaseHandler, FileHandler, FileType
from core.handlers.dataframe import CsvToDataframe
from core.prompts.error import ERROR_PROMPT
from core.tools.base import BaseToolSet
from core.tools.cpu import ExitConversation, RequestsGet
from core.tools.editor import CodeEditor
from core.tools.terminal import Terminal
from core.upload import StaticUploader
from env import settings
from logger import logger
app = FastAPI()

View File

@ -1,13 +1,11 @@
import os
import subprocess
import time
from datetime import datetime
from typing import Dict, List
from ansi import ANSI, Color, Style
from core.tools.base import BaseToolSet, SessionGetter, ToolScope, tool
from core.tools.terminal.stdout import StdoutTracer
from core.tools.terminal.syscall import SyscallTracer
from core.workers.tracer.stdout import StdoutTracer
from core.workers.tracer.syscall import SyscallTracer
from env import settings
from logger import logger

View File

@ -0,0 +1,39 @@
import os
import time
import subprocess
from datetime import datetime
from typing import Callable, Literal, Optional, Union, Tuple
from abc import ABC, abstractmethod
PipeType = Union[Literal["stdout"], Literal["stderr"]]
OnOutputHandler = Callable[[PipeType, str], None]
class BaseTracer(ABC):
def __init__(
self,
process: subprocess.Popen,
on_output: OnOutputHandler = lambda: None,
):
self.process: subprocess.Popen = process
self.on_output: OnOutputHandler = on_output
os.set_blocking(self.process.stdout.fileno(), False)
os.set_blocking(self.process.stderr.fileno(), False)
def get_output(self, pipe: PipeType) -> str:
output = None
if pipe == "stdout":
output = self.process.stdout.read()
elif pipe == "stderr":
output = self.process.stderr.read()
if output:
decoded = output.decode()
self.on_output(pipe, decoded)
self.last_output = datetime.now()
return decoded
return ""
@abstractmethod
def wait_until_stop_or_exit(self) -> Tuple[Optional[int], str]:
pass

View File

@ -2,48 +2,30 @@ import os
import time
import subprocess
from datetime import datetime
from typing import Callable, Literal, Optional, Union, Tuple
from typing import Literal, Optional, Union, Tuple
from .base import BaseTracer, OnOutputHandler
PipeType = Union[Literal["stdout"], Literal["stderr"]]
class StdoutTracer:
class StdoutTracer(BaseTracer):
def __init__(
self,
process: subprocess.Popen,
timeout: int = 30,
on_output: OnOutputHandler = lambda: None,
interval: int = 0.1,
on_output: Callable[[PipeType, str], None] = lambda: None,
):
super().__init__(process, on_output)
self.process: subprocess.Popen = process
self.timeout: int = timeout
self.interval: int = interval
self.last_output: datetime = None
self.on_output: Callable[[PipeType, str], None] = on_output
def nonblock(self):
os.set_blocking(self.process.stdout.fileno(), False)
os.set_blocking(self.process.stderr.fileno(), False)
def get_output(self, pipe: PipeType) -> str:
output = None
if pipe == "stdout":
output = self.process.stdout.read()
elif pipe == "stderr":
output = self.process.stderr.read()
if output:
decoded = output.decode()
self.on_output(pipe, decoded)
self.last_output = datetime.now()
return decoded
return ""
def last_output_passed(self, seconds: int) -> bool:
return (datetime.now() - self.last_output).seconds > seconds
def wait_until_stop_or_exit(self) -> Tuple[Optional[int], str]:
self.nonblock()
self.last_output = datetime.now()
output = ""
exitcode = None
@ -61,7 +43,6 @@ class StdoutTracer:
break
if self.last_output_passed(self.timeout):
self.process.kill()
break
time.sleep(self.interval)

View File

@ -1,4 +1,5 @@
import signal
import subprocess
from typing import Optional, Tuple
from ptrace.debugger import (
@ -12,6 +13,7 @@ from ptrace.debugger import (
from ptrace.func_call import FunctionCallOptions
from ptrace.syscall import PtraceSyscall
from ptrace.tools import signal_to_exitcode
from .base import BaseTracer, OnOutputHandler
class SyscallTimeoutException(Exception):
@ -19,10 +21,17 @@ class SyscallTimeoutException(Exception):
super().__init__(f"deadline exceeded while waiting syscall for {pid}", *args)
class SyscallTracer:
def __init__(self, pid: int):
class SyscallTracer(BaseTracer):
def __init__(
self,
process: subprocess.Popen,
timeout: int = 30,
on_output: OnOutputHandler = lambda: None,
):
super().__init__(process, on_output)
self.debugger: PtraceDebugger = PtraceDebugger()
self.pid: int = pid
self.pid: int = process.pid
self.timeout: int = timeout
self.process: PtraceProcess = None
def is_waiting(self, syscall: PtraceSyscall) -> bool:
@ -61,7 +70,7 @@ class SyscallTracer:
break
try:
self.wait_syscall_with_timeout(30)
self.wait_syscall_with_timeout(self.timeout)
except ProcessExit as event:
if event.exitcode is not None:
exitcode = event.exitcode