diff --git a/langchain/utilities/docker/__init__.py b/langchain/utilities/docker/__init__.py index c10ef35c..1325a777 100644 --- a/langchain/utilities/docker/__init__.py +++ b/langchain/utilities/docker/__init__.py @@ -1,30 +1,26 @@ """Wrapper for untrusted code exectuion on docker.""" -# TODO: Validation: -# - verify gVisor runtime (runsc) if available -# -TEST: spawned container: make sure it's ready ? before sending/reading commands -# - attach to running container -# - pass arbitrary image names -# - embed file payloads in the call to run (in LLMChain)? -# - image selection helper -# - LLMChain decorator ? +#TODO: attach to running container +#TODO: pull images +#TODO: embed file payloads in the call to run (in LLMChain)? +#TODO: image selection helper +#TODO: LLMChain decorator ? import docker import struct -import pandas as pd # type: ignore +import socket +import shlex +from time import sleep +import pandas as pd # type: ignore from docker.client import DockerClient # type: ignore -from docker.errors import APIError, ContainerError # type: ignore +from docker.errors import APIError, ContainerError # type: ignore import logging -from typing import Any, Dict -from typing import Optional -from pydantic import BaseModel, PrivateAttr, Extra, root_validator, validator +from .images import BaseImage, get_image_template, Python, Shell -logger = logging.getLogger(__name__) +from typing import Any, Dict, Optional, Union, Type +from pydantic import BaseModel, PrivateAttr, Extra, root_validator, validator, Field -docker_images = { - 'default': 'alpine:{version}', - 'python': 'python:{version}', - } +logger = logging.getLogger(__name__) SOCK_BUF_SIZE = 1024 @@ -55,9 +51,12 @@ _check_gvisor_runtime() class DockerSocket: """Wrapper around docker API's socket object. Can be used as a context manager.""" + _timeout: int = 10 - def __init__(self, socket): + + def __init__(self, socket, timeout: int = _timeout): self.socket = socket + self.socket._sock.settimeout(timeout) # self.socket._sock.setblocking(False) def __enter__(self): @@ -75,7 +74,10 @@ class DockerSocket: def sendall(self, data: bytes) -> None: self.socket._sock.sendall(data) - def recv(self): + def setblocking(self, flag: bool) -> None: + self.socket._sock.setblocking(flag) + + def recv(self) -> Any: """Wrapper for socket.recv that does buffured read.""" # NOTE: this is optional as a bonus @@ -149,17 +151,40 @@ class DockerSocket: return chunks +def _default_params() -> Dict: + return { + # the only required parameter to be able to attach. + 'stdin_open': True, + } +def _get_command(query: str, **kwargs: Dict) -> str: + """Build an escaped command from a query string and keyword arguments.""" + cmd = query + if 'default_command' in kwargs: + cmd = shlex.join([*kwargs.get('default_command'), query]) # type: ignore -class DockerWrapper(BaseModel, extra=Extra.forbid): - """Executes arbitrary payloads and returns the output.""" + return cmd - _docker_client: DockerClient = PrivateAttr() - image: Optional[str] = "alpine" +class DockerWrapper(BaseModel, extra=Extra.allow): + """Executes arbitrary payloads and returns the output. + + + Args: + image (str | Type[BaseImage]): Docker image to use for execution. The + image can be a string or a subclass of images.BaseImage. - # use env by default when create docker client - from_env: Optional[bool] = True + """ + _docker_client: DockerClient = PrivateAttr() + _params: Dict = Field(default_factory=Shell().dict(), skip=True) + image: Union[str, Type[BaseImage]] = Field(default_factory=Shell,skip=True) + from_env: Optional[bool] = Field(default=True, skip=True) + + # @property + # def image_name(self) -> str: + # """The image name that will be used when creating a container.""" + # return self._params.image + # def __init__(self, **kwargs): """Initialize docker client.""" super().__init__(**kwargs) @@ -167,6 +192,13 @@ class DockerWrapper(BaseModel, extra=Extra.forbid): if self.from_env: self._docker_client = docker.from_env() + # if not isinstance(self.image, str) and issubclass(self.image, BaseImage): + # self._params = {**self._params, **self.image().dict()} + # + # # if the user defined a custom image not pre registerd already we should + # # not use the custom command + # elif isinstance(self.image, str): + # self._params = {**_default_params(), **{'image': self.image}} @property def client(self) -> DockerClient: @@ -178,24 +210,87 @@ class DockerWrapper(BaseModel, extra=Extra.forbid): """Prints docker `info`.""" return self._docker_client.info() + # @validator("image", pre=True, always=True) + # def validate_image(cls, value): + # if value is None: + # raise ValueError("image is required") + # if isinstance(value, str) : + # image = get_image(value) + # if isinstance(image, BaseImage): + # return image + # else: + # #set default params to base ones + # if issubclass(value, BaseImage): + # return value + # else: + # raise ValueError("image must be a string or a subclass of images.BaseImage") + @root_validator() def validate_all(cls, values: Dict) -> Dict: """Validate environment.""" + image = values.get("image") + if image is None: + raise ValueError("image is required") + if isinstance(image, str): + # try to get image + _image = get_image_template(image) + if isinstance(_image, str): + # user wants a custom image, we should use default params + values["_params"] = {**_default_params(), **{'image': image}} + else: + # user wants a pre registered image, we should use the image params + values["_params"] = _image().dict() + # image is a BaseImage class + elif issubclass(image.__class__, BaseImage): + values["_params"] = image.dict() + + + def field_filter(x): + fields = cls.__fields__ + if x[0] == '_params': + return False + field = fields.get(x[0], None) + if not field: + return True + return not field.field_info.extra.get('skip', False) + filtered_fields: Dict[Any, Any] = dict(filter(field_filter, values.items())) # type: ignore + values["_params"] = {**values["_params"], + **filtered_fields} + return values def run(self, query: str, **kwargs: Any) -> str: """Run arbitrary shell command inside a container. + This method will concatenate the registered default command with the provided + query. + Args: + query (str): The command to run. **kwargs: Pass extra parameters to DockerClient.container.run. """ + kwargs = {**self._params, **kwargs} + args = { + 'image': self._params.get('image'), + 'command': query, + } + + del kwargs['image'] + cmd = _get_command(query, **kwargs) + kwargs.pop('default_command', None) + + args['command'] = cmd + # print(f"args: {args}") + # print(f"kwargs: {kwargs}") + # return + logger.debug(f"running command {args['command']}") + logger.debug(f"with params {kwargs}") try: - image = kwargs.get("image", self.image) - return self._docker_client.containers.run(image, - query, + result= self._docker_client.containers.run(*(args.values()), remove=True, **kwargs) + return result.decode('utf-8').strip() except ContainerError as e: return f"STDERR: {e}" @@ -206,11 +301,15 @@ class DockerWrapper(BaseModel, extra=Extra.forbid): - def exec_run(self, query: str, image: str) -> str: - """Run arbitrary shell command inside a container. + def exec_run(self, query: str, **kwargs: Any) -> str: + """Run arbitrary shell command inside an ephemeral container. - This is a lower level API that sends the input to the container's - stdin through a socket using Docker API. It effectively simulates a tty session. + This will create a container, run the command, and then remove the + container. the input is sent to the container's stdin through a socket + using Docker API. It effectively simulates a tty session. + + Args: + **kwargs: Pass extra parameters to DockerClient.container.exec_run. """ # it is necessary to open stdin to keep the container running after it's started # the attach_socket will hold the connection open until the container is stopped or @@ -222,40 +321,65 @@ class DockerWrapper(BaseModel, extra=Extra.forbid): # parameters to keep stdin open. For example python image needs to be # started with the command `python3 -i` + + kwargs = {**self._params, **kwargs} + if 'default_command' in kwargs: + kwargs['command'] = shlex.join(kwargs['default_command']) + del kwargs['default_command'] + # cmd = _get_command(query, **kwargs) + # kwargs.pop('default_command', None) + # kwargs['command'] = cmd + + # print(f"kwargs: {kwargs}") + # return + # TODO: handle both output mode for tty=True/False - container = self._docker_client.containers.create(image, stdin_open=True) + logger.debug(f"running command {kwargs['command']}") + logger.debug(f"with params {kwargs}") + container = self._docker_client.containers.create(**kwargs) container.start() - # input() # get underlying socket + # important to set 'stream' or attach API does not work _socket = container.attach_socket(params={'stdout': 1, 'stderr': 1, 'stdin': 1, 'stream': 1}) - with DockerSocket(_socket) as socket: - # TEST: make sure the container is ready ? use a blocking call first - socket.sendall(query.encode('utf-8')) + + # input() + with DockerSocket(_socket) as _socket: + # flush the output buffer (if any prompt) + flush = _socket.recv() + _socket.setblocking(True) + print(flush) + # TEST: make sure the container is ready ? use a blocking first call + _socket.sendall(query.encode('utf-8')) + #FIX: is it possible to know if command is finished ? + sleep(0.5) #this should be available as a parameter # read the output output = None - output = socket.recv() - # print(output) - + try: + output = _socket.recv() + except socket.timeout: + return "ERROR: timeout" container.kill() container.remove() + # output is stored in a list of tuples (stream_type, payload) df = pd.DataFrame(output, columns=['stream_type', 'payload']) df['payload'] = df['payload'].apply(lambda x: x.decode('utf-8')) df['stream_type'] = df['stream_type'].apply(lambda x: 'stdout' if x == 1 else 'stderr') payload = df.groupby('stream_type')['payload'].apply(''.join).to_dict() - print(payload) + logger.debug(f"payload: {payload}") + #HACK: better output handling when stderr is present + #NOTE: stderr might just contain the prompt if 'stdout' in payload and 'stderr' in payload: - return f"STDOUT:\n {payload['stdout']}\nSTDERR: {payload['stderr']}" - elif 'stderr' in payload: + return f"STDOUT:\n {payload['stdout']}\nSTDERR:\n {payload['stderr']}" + if 'stderr' in payload and not 'stdout' in payload: return f"STDERR: {payload['stderr']}" else: return payload['stdout'] - diff --git a/langchain/utilities/docker/images.py b/langchain/utilities/docker/images.py index eef0aa4e..a5f15775 100644 --- a/langchain/utilities/docker/images.py +++ b/langchain/utilities/docker/images.py @@ -1,18 +1,37 @@ -"""Optimized parameters for commonly used docker images that can be used by -the docker wrapper utility to attach to.""" +"""This module defines template images and halpers for common docker images.""" from enum import Enum -from typing import Optional, List +from typing import Optional, List, Type, Union from pydantic import BaseModel, Extra, validator + class BaseImage(BaseModel, extra=Extra.forbid): - """Base docker image class.""" + """Base docker image template class.""" tty: bool = False stdin_open: bool = True - image: str + name: str + tag: Optional[str] = 'latest' default_command: Optional[List[str]] = None + def dict(self, *args, **kwargs): + """Override the dict method to add the image name.""" + d = super().dict(*args, **kwargs) + del d['name'] + del d['tag'] + # del d['default_command'] + d['image'] = self.image_name + # if self.default_command: + # d['command'] = self.default_command + return d + + @property + def image_name(self) -> str: + """Image name.""" + return f'{self.name}:{self.tag}' + + + class ShellTypes(str, Enum): """Enum class for shell types.""" bash = '/bin/bash' @@ -26,10 +45,10 @@ class Shell(BaseImage): A shell image can be crated by passing a shell alias such as `sh` or `bash` or by passing the full path to the shell binary. """ - image: str = 'alpine' - shell: str = ShellTypes.bash.value + name: str = 'alpine' + default_command: List[str] = [ShellTypes.sh.value, '-c'] - @validator('shell') + @validator('default_command') def validate_shell(cls, value: str) -> str: """Validate shell type.""" val = getattr(ShellTypes, value, None) @@ -46,12 +65,38 @@ class Python(BaseImage): The python image needs to be launced using the `python3 -i` command to keep stdin open. """ - image: str = 'python' + name: str = 'python' default_command: List[str] = ['python3', '-i'] def __setattr__(self, name, value): if name == 'default_command': raise AttributeError(f'running this image with {self.default_command}' ' is necessary to keep stdin open.') - super().__setattr__(name, value) + + +def get_image_template(image_name: str = 'shell') -> Union[str, Type[BaseImage]]: + """Helper to get an image template from a string. + + It tries to find a class with the same name as the image name and returns the + class. If no class is found, it returns the image name. + + + .. code-block:: python + + >>> image = get_image_template('python') + >>> assert type(image) == Python + """ + import importlib + import inspect + + classes = inspect.getmembers(importlib.import_module(__name__), + lambda x: inspect.isclass(x) and x.__name__ == image_name.capitalize() + ) + + if classes: + cls = classes[0][1] + return cls + else: + return image_name + diff --git a/tests/unit_tests/test_docker.py b/tests/unit_tests/test_docker.py index 75ff8041..11fb471f 100644 --- a/tests/unit_tests/test_docker.py +++ b/tests/unit_tests/test_docker.py @@ -1,7 +1,8 @@ """Test the docker wrapper utility.""" import pytest -from langchain.utilities.docker import DockerWrapper, gvisor_runtime_available +from langchain.utilities.docker import DockerWrapper, \ + gvisor_runtime_available, _default_params from unittest.mock import MagicMock import subprocess @@ -21,11 +22,25 @@ def docker_installed() -> bool: @pytest.mark.skipif(not docker_installed(), reason="docker not installed") class TestDockerUtility: - def test_command_default_image(self) -> None: + def test_default_image(self) -> None: """Test running a command with the default alpine image.""" docker = DockerWrapper() output = docker.run('cat /etc/os-release') - assert output.find(b'alpine') + assert output.find('alpine') + + def test_shell_escaping(self) -> None: + docker = DockerWrapper() + output = docker.run('echo "hello world" | sed "s/world/you/g"') + assert output == 'hello you' + # using embedded quotes + output = docker.run("echo 'hello world' | awk '{print $2}'") + assert output == 'world' + + def test_auto_pull_image(self) -> None: + docker = DockerWrapper(image='golang:1.20') + output = docker.run("go version") + assert output.find('go1.20') + docker._docker_client.images.remove('golang:1.20') def test_inner_failing_command(self) -> None: """Test inner command with non zero exit""" @@ -46,3 +61,21 @@ class TestDockerUtility: assert gvisor_runtime_available(mock_client) mock_client.info.return_value = {'Runtimes': {'runc': {'path': 'runc'}}} assert not gvisor_runtime_available(mock_client) + + def test_socket_read_timeout(self) -> None: + """Test socket read timeout.""" + docker = DockerWrapper(image='python', command='python') + # this query should fail as python needs to be started with python3 -i + output = docker.exec_run("test query") + assert output == "ERROR: timeout" + +def test_get_image_template() -> None: + """Test getting an image template instance from string.""" + from langchain.utilities.docker.images import get_image_template + image = get_image_template("python") + assert image.__name__ == "Python" # type: ignore + +def test_default_params() -> None: + """Test default container parameters.""" + docker = DockerWrapper(image="my_custom_image") + assert docker._params == {**_default_params(), "image": "my_custom_image"}