diff --git a/api/ssh.go b/api/ssh.go index 7a2ba282..15c3c4b2 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -23,7 +23,7 @@ type SSHAuthority interface { GetSSHFederation() (*authority.SSHKeys, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) CheckSSHHost(principal string) (bool, error) - GetSSHHosts() ([]string, error) + GetSSHHosts(user string) ([]string, error) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) } @@ -406,7 +406,18 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { - hosts, err := h.Authority.GetSSHHosts() + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + WriteError(w, BadRequest(errors.New("missing peer certificate"))) + return + } + + cert := r.TLS.PeerCertificates[0] + email := cert.EmailAddresses[0] + if len(email) == 0 { + WriteError(w, BadRequest(errors.New("client certificate missing email SAN"))) + return + } + hosts, err := h.Authority.GetSSHHosts(email) if err != nil { WriteError(w, InternalServerError(err)) return diff --git a/authority/authority.go b/authority/authority.go index 3177efd9..e00d978c 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -41,6 +41,7 @@ type Authority struct { initOnce bool // Custom functions sshBastionFunc func(user, hostname string) (*Bastion, error) + sshGetHostsFunc func(user string) ([]string, error) getIdentityFunc provisioner.GetIdentityFunc } diff --git a/authority/options.go b/authority/options.go index 409e8c2d..5a161118 100644 --- a/authority/options.go +++ b/authority/options.go @@ -16,6 +16,14 @@ func WithDatabase(db db.AuthDB) Option { } } +// WithGetIdentityFunc sets a custom function to retrieve the identity from +// an external resource. +func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { + return func(a *Authority) { + a.getIdentityFunc = fn + } +} + // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { @@ -24,10 +32,10 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { } } -// WithGetIdentityFunc sets a custom function to retrieve the identity from -// an external resource. -func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { +// WithSSHGetHosts sets a custom function to get the bastion for a +// given user-host pair. +func WithSSHGetHosts(fn func(user string) ([]string, error)) Option { return func(a *Authority) { - a.getIdentityFunc = fn + a.sshGetHostsFunc = fn } } diff --git a/authority/ssh.go b/authority/ssh.go index 67c884b8..36806ae6 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -673,13 +673,14 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) { } // GetSSHHosts returns a list of valid host principals. -func (a *Authority) GetSSHHosts() ([]string, error) { - ps, err := a.db.GetSSHHostPrincipals() - if err != nil { - return nil, err +func (a *Authority) GetSSHHosts(email string) ([]string, error) { + if a.sshBastionFunc != nil { + return a.sshGetHostsFunc(email) + } + return nil, &apiError{ + err: errors.New("getSSHHosts is not configured"), + code: http.StatusNotFound, } - - return ps, nil } func (a *Authority) getAddUserPrincipal() (cmd string) {