refactor and add test
parent
afe3d23cbd
commit
16848ef29b
@ -0,0 +1,156 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type dependencies struct {
|
||||
// storing all these dependencies as fields for the sake of testing
|
||||
dialContext func(ctx context.Context, network, addr string) (io.Closer, error)
|
||||
startCmd func(*exec.Cmd) error
|
||||
tempDir func(dir string, pattern string) (name string, err error)
|
||||
getenv func(key string) string
|
||||
setenv func(key, value string) error
|
||||
}
|
||||
|
||||
type SSHHandler struct {
|
||||
deps dependencies
|
||||
}
|
||||
|
||||
func NewSSHHandler() *SSHHandler {
|
||||
return &SSHHandler{
|
||||
deps: dependencies{
|
||||
dialContext: func(ctx context.Context, network, addr string) (io.Closer, error) {
|
||||
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
||||
},
|
||||
startCmd: func(cmd *exec.Cmd) error { return cmd.Start() },
|
||||
tempDir: ioutil.TempDir,
|
||||
getenv: os.Getenv,
|
||||
setenv: os.Setenv,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSSHDockerHost overrides the DOCKER_HOST environment variable
|
||||
// to point towards a local unix socket tunneled over SSH to the specified ssh host.
|
||||
func (self *SSHHandler) HandleSSHDockerHost() (io.Closer, error) {
|
||||
const key = "DOCKER_HOST"
|
||||
ctx := context.Background()
|
||||
u, err := url.Parse(self.deps.getenv(key))
|
||||
if err != nil {
|
||||
// if no or an invalid docker host is specified, continue nominally
|
||||
return noopCloser{}, nil
|
||||
}
|
||||
|
||||
// if the docker host scheme is "ssh", forward the docker socket before creating the client
|
||||
if u.Scheme == "ssh" {
|
||||
tunnel, err := self.createDockerHostTunnel(ctx, u.Host)
|
||||
if err != nil {
|
||||
return noopCloser{}, fmt.Errorf("tunnel ssh docker host: %w", err)
|
||||
}
|
||||
err = self.deps.setenv(key, tunnel.socketPath)
|
||||
if err != nil {
|
||||
return noopCloser{}, fmt.Errorf("override DOCKER_HOST to tunneled socket: %w", err)
|
||||
}
|
||||
|
||||
return tunnel, nil
|
||||
}
|
||||
return noopCloser{}, nil
|
||||
}
|
||||
|
||||
type noopCloser struct{}
|
||||
|
||||
func (noopCloser) Close() error { return nil }
|
||||
|
||||
type tunneledDockerHost struct {
|
||||
socketPath string
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
var _ io.Closer = (*tunneledDockerHost)(nil)
|
||||
|
||||
func (t *tunneledDockerHost) Close() error {
|
||||
return syscall.Kill(-t.cmd.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
|
||||
func (self *SSHHandler) createDockerHostTunnel(ctx context.Context, remoteHost string) (*tunneledDockerHost, error) {
|
||||
socketDir, err := self.deps.tempDir("/tmp", "lazydocker-sshtunnel-")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ssh tunnel tmp file: %w", err)
|
||||
}
|
||||
localSocket := path.Join(socketDir, "dockerhost.sock")
|
||||
|
||||
cmd, err := self.tunnelSSH(ctx, remoteHost, localSocket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tunnel docker host over ssh: %w", err)
|
||||
}
|
||||
|
||||
// set a reasonable timeout, then wait for the socket to dial successfully
|
||||
// before attempting to create a new docker client
|
||||
const socketTunnelTimeout = 8 * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, socketTunnelTimeout)
|
||||
defer cancel()
|
||||
|
||||
err = self.retrySocketDial(ctx, localSocket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssh tunneled socket never became available: %w", err)
|
||||
}
|
||||
|
||||
// construct the new DOCKER_HOST url with the proper scheme
|
||||
newDockerHostURL := url.URL{Scheme: "unix", Path: localSocket}
|
||||
return &tunneledDockerHost{
|
||||
socketPath: newDockerHostURL.String(),
|
||||
cmd: cmd,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Attempt to dial the socket until it becomes available.
|
||||
// The retry loop will continue until the parent context is canceled.
|
||||
func (self *SSHHandler) retrySocketDial(ctx context.Context, socketPath string) error {
|
||||
t := time.NewTicker(1 * time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-t.C:
|
||||
}
|
||||
// attempt to dial the socket, exit on success
|
||||
err := self.tryDial(ctx, socketPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try to dial the specified unix socket, immediately close the connection if successfully created.
|
||||
func (self *SSHHandler) tryDial(ctx context.Context, socketPath string) error {
|
||||
conn, err := self.deps.dialContext(ctx, "unix", socketPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (self *SSHHandler) tunnelSSH(ctx context.Context, host, localSocket string) (*exec.Cmd, error) {
|
||||
cmd := exec.CommandContext(ctx, "ssh", "-L", localSocket+":/var/run/docker.sock", host, "-N")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
err := self.deps.startCmd(cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
@ -0,0 +1,102 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSSHHandlerHandleSSHDockerHost(t *testing.T) {
|
||||
type scenario struct {
|
||||
testName string
|
||||
envVarValue string
|
||||
expectedDialContextCount int
|
||||
expectedStartCmdCount int
|
||||
}
|
||||
|
||||
scenarios := []scenario{
|
||||
{
|
||||
testName: "No env var set",
|
||||
envVarValue: "",
|
||||
expectedDialContextCount: 0,
|
||||
expectedStartCmdCount: 0,
|
||||
},
|
||||
{
|
||||
testName: "Env var set with https scheme",
|
||||
envVarValue: "https://myhost.com",
|
||||
expectedStartCmdCount: 0,
|
||||
expectedDialContextCount: 0,
|
||||
},
|
||||
{
|
||||
testName: "Env var set with ssh scheme",
|
||||
envVarValue: "ssh://myhost@192.168.5.178",
|
||||
expectedStartCmdCount: 1,
|
||||
expectedDialContextCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
s := s
|
||||
t.Run(s.testName, func(t *testing.T) {
|
||||
getenv := func(key string) string {
|
||||
if key != "DOCKER_HOST" {
|
||||
t.Errorf("Expected key to be DOCKER_HOST, got %s", key)
|
||||
}
|
||||
|
||||
return s.envVarValue
|
||||
}
|
||||
|
||||
tempDir := func(dir string, pattern string) (string, error) {
|
||||
assert.Equal(t, "/tmp", dir)
|
||||
assert.Equal(t, "lazydocker-sshtunnel-", pattern)
|
||||
|
||||
return "/tmp/lazydocker-ssh-tunnel-12345", nil
|
||||
}
|
||||
|
||||
setenv := func(key, value string) error {
|
||||
assert.Equal(t, "DOCKER_HOST", key)
|
||||
assert.Equal(t, "unix:///tmp/lazydocker-ssh-tunnel-12345/dockerhost.sock", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
startCmdCount := 0
|
||||
startCmd := func(cmd *exec.Cmd) error {
|
||||
assert.EqualValues(t, []string{"ssh", "-L", "/tmp/lazydocker-ssh-tunnel-12345/dockerhost.sock:/var/run/docker.sock", "192.168.5.178", "-N"}, cmd.Args)
|
||||
assert.Equal(t, true, cmd.SysProcAttr.Setpgid)
|
||||
|
||||
startCmdCount++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
dialContextCount := 0
|
||||
dialContext := func(ctx context.Context, network string, address string) (io.Closer, error) {
|
||||
assert.Equal(t, "unix", network)
|
||||
assert.Equal(t, "/tmp/lazydocker-ssh-tunnel-12345/dockerhost.sock", address)
|
||||
|
||||
dialContextCount++
|
||||
|
||||
return noopCloser{}, nil
|
||||
}
|
||||
|
||||
handler := &SSHHandler{
|
||||
deps: dependencies{
|
||||
dialContext: dialContext,
|
||||
startCmd: startCmd,
|
||||
tempDir: tempDir,
|
||||
getenv: getenv,
|
||||
setenv: setenv,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := handler.HandleSSHDockerHost()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, s.expectedDialContextCount, dialContextCount)
|
||||
assert.Equal(t, s.expectedStartCmdCount, startCmdCount)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue