diff --git a/api/api.go b/api/api.go index 68334dcb..334def24 100644 --- a/api/api.go +++ b/api/api.go @@ -261,6 +261,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) + r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) diff --git a/api/api_test.go b/api/api_test.go index 1938e300..e68eb7db 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -557,6 +557,7 @@ type mockAuthority struct { getSSHFederation func() (*authority.SSHKeys, error) getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) checkSSHHost func(principal string) (bool, error) + getSSHBastion func(user string, hostname string) (*authority.Bastion, error) } // TODO: remove once Authorize is deprecated. @@ -711,6 +712,13 @@ func (m *mockAuthority) CheckSSHHost(principal string) (bool, error) { return m.ret1.(bool), m.err } +func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) { + if m.getSSHBastion != nil { + return m.getSSHBastion(user, hostname) + } + return m.ret1.(*authority.Bastion), m.err +} + func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority diff --git a/api/ssh.go b/api/ssh.go index 11fd3a89..7a2ba282 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -24,6 +24,7 @@ type SSHAuthority interface { GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) CheckSSHHost(principal string) (bool, error) GetSSHHosts() ([]string, error) + GetSSHBastion(user string, hostname string) (*authority.Bastion, error) } // SSHSignRequest is the request body of an SSH certificate request. @@ -207,6 +208,28 @@ type SSHCheckPrincipalResponse struct { Exists bool `json:"exists"` } +// SSHBastionRequest is the request body used to get the bastion for a given +// host. +type SSHBastionRequest struct { + User string `json:"user"` + Hostname string `json:"hostname"` +} + +// Validate checks the values of the SSHBastionRequest. +func (r *SSHBastionRequest) Validate() error { + if r.Hostname == "" { + return errors.New("missing or empty hostname") + } + return nil +} + +// SSHBastionResponse is the response body used to return the bastion for a +// given host. +type SSHBastionResponse struct { + Hostname string `json:"hostname"` + Bastion *authority.Bastion `json:"bastion,omitempty"` +} + // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. @@ -392,3 +415,27 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { Hosts: hosts, }) } + +// SSHBastion provides returns the bastion configured if any. +func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { + var body SSHBastionRequest + if err := ReadJSON(r.Body, &body); err != nil { + WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + return + } + if err := body.Validate(); err != nil { + WriteError(w, BadRequest(err)) + return + } + + bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) + if err != nil { + WriteError(w, InternalServerError(err)) + return + } + + JSON(w, &SSHBastionResponse{ + Hostname: body.Hostname, + Bastion: bastion, + }) +} diff --git a/api/ssh_test.go b/api/ssh_test.go index e4e2fd9b..cc615ee7 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -537,6 +537,61 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } } +func Test_caHandler_SSHBastion(t *testing.T) { + bastion := &authority.Bastion{ + Hostname: "bastion.local", + } + bastionPort := &authority.Bastion{ + Hostname: "bastion.local", + Port: "2222", + } + + tests := []struct { + name string + bastion *authority.Bastion + bastionErr error + req []byte + body []byte + statusCode int + }{ + {"ok", bastion, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local"}}`), http.StatusOK}, + {"ok", bastionPort, nil, []byte(`{"hostname":"host.local","user":"user"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local","port":"2222"}}`), http.StatusOK}, + {"empty", nil, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local"}`), http.StatusOK}, + {"bad json", bastion, nil, []byte(`bad json`), nil, http.StatusBadRequest}, + {"bad request", bastion, nil, []byte(`{"hostname": ""}`), nil, http.StatusBadRequest}, + {"error", nil, fmt.Errorf("an error"), []byte(`{"hostname":"host.local"}`), nil, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + getSSHBastion: func(user, hostname string) (*authority.Bastion, error) { + return tt.bastion, tt.bastionErr + }, + }).(*caHandler) + + req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) + w := httptest.NewRecorder() + h.SSHBastion(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.SSHBastion unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.SSHBastion Body = %s, wants %s", body, tt.body) + } + } + }) + } +} + func TestSSHPublicKey_MarshalJSON(t *testing.T) { key, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) diff --git a/authority/authority.go b/authority/authority.go index 05a2e43a..091b84b9 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -39,6 +39,8 @@ type Authority struct { db db.AuthDB // Do not re-initialize initOnce bool + // Custom functions + sshBastionFunc func(user, hostname string) (*Bastion, error) } // New creates and initiates a new Authority type. diff --git a/authority/options.go b/authority/options.go index ebf6fe08..3d602255 100644 --- a/authority/options.go +++ b/authority/options.go @@ -14,3 +14,11 @@ func WithDatabase(db db.AuthDB) Option { a.db = db } } + +// WithSSHBastionFunc defines sets a custom function to get the bastion for a +// given user-host pair. +func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { + return func(a *Authority) { + a.sshBastionFunc = fn + } +} diff --git a/authority/ssh.go b/authority/ssh.go index 338f1da1..67c884b8 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -34,8 +34,18 @@ type SSHConfig struct { HostKey string `json:"hostKey"` UserKey string `json:"userKey"` Keys []*SSHPublicKey `json:"keys,omitempty"` - AddUserPrincipal string `json:"addUserPrincipal"` - AddUserCommand string `json:"addUserCommand"` + AddUserPrincipal string `json:"addUserPrincipal,omitempty"` + AddUserCommand string `json:"addUserCommand,omitempty"` + Bastion *Bastion `json:"bastion,omitempty"` +} + +// Bastion contains the custom properties used on bastion. +type Bastion struct { + Hostname string `json:"hostname"` + User string `json:"user,omitempty"` + Port string `json:"port,omitempty"` + Command string `json:"cmd,omitempty"` + Flags string `json:"flags,omitempty"` } // Validate checks the fields in SSHConfig. @@ -157,6 +167,24 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template return output, nil } +// GetSSHBastion returns the bastion configuration, for the given pair user, +// hostname. +func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) { + if a.sshBastionFunc != nil { + return a.sshBastionFunc(user, hostname) + } + if a.config.SSH != nil { + if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" { + return a.config.SSH.Bastion, nil + } + return nil, nil + } + return nil, &apiError{ + err: errors.New("getSSHBastion: ssh is not configured"), + code: http.StatusNotFound, + } +} + // authorizeSSHSign loads the provisioner from the token, checks that it has not // been used again and calls the provisioner AuthorizeSSHSign method. Returns a // list of methods to apply to the signing flow. diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 629bc3b4..c2f4ceb7 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" @@ -598,3 +599,48 @@ func TestSSHPublicKey_PublicKey(t *testing.T) { }) } } + +func TestAuthority_GetSSHBastion(t *testing.T) { + bastion := &Bastion{ + Hostname: "bastion.local", + Port: "2222", + } + type fields struct { + config *Config + sshBastionFunc func(user, hostname string) (*Bastion, error) + } + type args struct { + user string + hostname string + } + tests := []struct { + name string + fields fields + args args + want *Bastion + wantErr bool + }{ + {"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false}, + {"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false}, + {"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false}, + {"func", fields{&Config{}, func(_, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false}, + {"func err", fields{&Config{}, func(_, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true}, + {"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + sshBastionFunc: tt.fields.sshBastionFunc, + } + got, err := a.GetSSHBastion(tt.args.user, tt.args.hostname) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) + } + }) + } +}