From 149fe0055ef5602b2635cef9313df7659acd6515 Mon Sep 17 00:00:00 2001 From: blob42 Date: Thu, 2 Mar 2023 20:39:48 +0100 Subject: [PATCH] exec_run fixes to keep stdin open --- langchain/utilities/docker/__init__.py | 44 +++++++++++++++++--------- langchain/utilities/docker/images.py | 4 +-- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/langchain/utilities/docker/__init__.py b/langchain/utilities/docker/__init__.py index 92c34703..59f01706 100644 --- a/langchain/utilities/docker/__init__.py +++ b/langchain/utilities/docker/__init__.py @@ -182,7 +182,7 @@ class DockerWrapper(BaseModel, extra=Extra.allow): default_command (List[str]): Default command to use when creating the container. """ - _docker_client: DockerClient = PrivateAttr() + _docker_client: DockerClient = None _params: Dict = Field(default_factory=Shell().dict(), skip=True) image: Union[str, Type[BaseImage]] = Field(default_factory=Shell,skip=True) from_env: Optional[bool] = Field(default=True, skip=True) @@ -309,9 +309,17 @@ class DockerWrapper(BaseModel, extra=Extra.allow): + def _flush_prompt(self, _socket): + flush = _socket.recv() + _socket.setblocking(True) + logger.debug(f"flushed output: {flush}") + + + def exec_run(self, query: str, timeout: int = 5, delay: float = 0.5, with_stderr: bool = False, + flush_prompt: bool = False, **kwargs: Any) -> str: """Run arbitrary shell command inside an ephemeral container. @@ -322,6 +330,8 @@ class DockerWrapper(BaseModel, extra=Extra.allow): Args: timeout (int): The timeout for receiving from the attached stdin. delay (int): The delay in seconds before running the command. + with_stderr (bool): If True, the stderr will be included in the output + flush_prompt (bool): If True, the prompt will be flushed before running the command. **kwargs: Pass extra parameters to DockerClient.container.exec_run. """ # it is necessary to open stdin to keep the container running after it's started @@ -342,7 +352,7 @@ class DockerWrapper(BaseModel, extra=Extra.allow): kwargs = {**self._params, **kwargs} if 'default_command' in kwargs: - kwargs['command'] = shlex.join(kwargs['default_command']) + kwargs['command'] = shlex.join([*kwargs['default_command'],query]) del kwargs['default_command'] # kwargs.pop('default_command', None) @@ -353,7 +363,7 @@ class DockerWrapper(BaseModel, extra=Extra.allow): # TODO: handle both output mode for tty=True/False logger.debug(f"running command {kwargs['command']}") - logger.debug(f"with params {kwargs}") + logger.debug(f"creating container with params {kwargs}") container = self._docker_client.containers.create(**kwargs) container.start() @@ -366,25 +376,29 @@ class DockerWrapper(BaseModel, extra=Extra.allow): # input() with DockerSocket(_socket, timeout=timeout) as _socket: # flush the output buffer (if any prompt) - output = None - try: - flush = _socket.recv() - _socket.setblocking(True) - logger.debug(f"flushed output: {flush}") - # TEST: make sure the container is ready ? use a blocking first call - _socket.sendall(query.encode('utf-8')) + if flush_prompt: + self._flush_prompt(_socket) + + # TEST: make sure the container is ready ? use a blocking first call + raw_input = f"{query}\n".encode('utf-8') + _socket.sendall(raw_input) - #NOTE: delay ensures that the command is executed after the input is sent - sleep(delay) #this should be available as a parameter + #NOTE: delay ensures that the command is executed after the input is sent + sleep(delay) #this should be available as a parameter - # read the output + # read the output + output = None + try: output = _socket.recv() except socket.timeout: return "ERROR: timeout" - container.kill() - container.remove() + try: + container.kill() + except APIError: + pass + container.remove(force=True) # output is stored in a list of tuples (stream_type, payload) diff --git a/langchain/utilities/docker/images.py b/langchain/utilities/docker/images.py index a5f15775..882fd626 100644 --- a/langchain/utilities/docker/images.py +++ b/langchain/utilities/docker/images.py @@ -46,7 +46,7 @@ class Shell(BaseImage): or by passing the full path to the shell binary. """ name: str = 'alpine' - default_command: List[str] = [ShellTypes.sh.value, '-c'] + default_command: List[str] = [ShellTypes.sh.value, '-s'] @validator('default_command') def validate_shell(cls, value: str) -> str: @@ -66,7 +66,7 @@ class Python(BaseImage): stdin open. """ name: str = 'python' - default_command: List[str] = ['python3', '-i'] + default_command: List[str] = ['python3', '-iq'] def __setattr__(self, name, value): if name == 'default_command':