diff --git a/api/api.go b/api/api.go index 0284167f..ad8fbb98 100644 --- a/api/api.go +++ b/api/api.go @@ -257,6 +257,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("POST", "/ssh/config", h.SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) + r.MethodFunc("POST", "/ssh/get-hosts", h.SSHGetHosts) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) diff --git a/api/ssh.go b/api/ssh.go index e3101b8b..11d59712 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -21,6 +21,7 @@ type SSHAuthority interface { GetSSHFederation() (*authority.SSHKeys, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) CheckSSHHost(principal string) (bool, error) + GetSSHHosts() ([]string, error) } // SSHSignRequest is the request body of an SSH certificate request. @@ -66,6 +67,11 @@ type SSHCertificate struct { *ssh.Certificate `json:"omitempty"` } +// SSHGetHostsResponse +type SSHGetHostsResponse struct { + Hosts []string `json:"hosts"` +} + // MarshalJSON implements the json.Marshaler interface. Returns a quoted, // base64 encoded, openssh wire format version of the certificate. func (c SSHCertificate) MarshalJSON() ([]byte, error) { @@ -369,3 +375,15 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { Exists: exists, }) } + +// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. +func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { + hosts, err := h.Authority.GetSSHHosts() + if err != nil { + WriteError(w, InternalServerError(err)) + return + } + JSON(w, &SSHGetHostsResponse{ + Hosts: hosts, + }) +} diff --git a/authority/ssh.go b/authority/ssh.go index 1c3f39bb..74833256 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -369,6 +369,16 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) { return exists, nil } +// GetSSHHosts returns a list of valid host principals. +func (a *Authority) GetSSHHosts() ([]string, error) { + ps, err := a.db.GetSSHHostPrincipals() + if err != nil { + return nil, err + } + + return ps, nil +} + func (a *Authority) getAddUserPrincipal() (cmd string) { if a.config.SSH.AddUserPrincipal == "" { return SSHAddUserPrincipal diff --git a/ca/client.go b/ca/client.go index 509ebb7c..8cefe4c0 100644 --- a/ca/client.go +++ b/ca/client.go @@ -611,6 +611,24 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, return &check, nil } +// SSHGetHostPrincipals performs the POST /ssh/check-host request to the CA with the +// given principal. +func (c *Client) SSHGetHostPrincipals() (*api.SSHGetHostsResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var hosts api.SSHGetHostsResponse + if err := readJSON(resp.Body, &hosts); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &hosts, nil +} + // RootFingerprint is a helper method that returns the current root fingerprint. // It does an health connection and gets the fingerprint from the TLS verified // chains. diff --git a/db/db.go b/db/db.go index 17e5c209..2aa093b4 100644 --- a/db/db.go +++ b/db/db.go @@ -14,12 +14,13 @@ import ( ) var ( - certsTable = []byte("x509_certs") - revokedCertsTable = []byte("revoked_x509_certs") - usedOTTTable = []byte("used_ott") - sshCertsTable = []byte("ssh_certs") - sshHostsTable = []byte("ssh_hosts") - sshUsersTable = []byte("ssh_users") + certsTable = []byte("x509_certs") + revokedCertsTable = []byte("revoked_x509_certs") + usedOTTTable = []byte("used_ott") + sshCertsTable = []byte("ssh_certs") + sshHostsTable = []byte("ssh_hosts") + sshUsersTable = []byte("ssh_users") + sshHostPrincipalsTable = []byte("ssh_host_principals") ) // ErrAlreadyExists can be returned if the DB attempts to set a key that has @@ -42,6 +43,7 @@ type AuthDB interface { UseToken(id, tok string) (bool, error) IsSSHHost(name string) (bool, error) StoreSSHCertificate(crt *ssh.Certificate) error + GetSSHHostPrincipals() ([]string, error) Shutdown() error } @@ -160,19 +162,32 @@ func (db *DB) IsSSHHost(principal string) (bool, error) { return true, nil } +type sshHostPrincipalData struct { + Serial string + Expiry uint64 +} + // StoreSSHCertificate stores an SSH certificate. func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error { - var table []byte serial := strconv.FormatUint(crt.Serial, 10) tx := new(database.Tx) tx.Set(sshCertsTable, []byte(serial), crt.Marshal()) if crt.CertType == ssh.HostCert { - table = sshHostsTable + for _, p := range crt.ValidPrincipals { + hostPrincipalData, err := json.Marshal(sshHostPrincipalData{ + Serial: serial, + Expiry: crt.ValidBefore, + }) + if err != nil { + return err + } + tx.Set(sshHostsTable, []byte(strings.ToLower(p)), []byte(serial)) + tx.Set(sshHostPrincipalsTable, []byte(strings.ToLower(p)), hostPrincipalData) + } } else { - table = sshUsersTable - } - for _, p := range crt.ValidPrincipals { - tx.Set(table, []byte(strings.ToLower(p)), []byte(serial)) + for _, p := range crt.ValidPrincipals { + tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial)) + } } if err := db.Update(tx); err != nil { return errors.Wrap(err, "database Update error") @@ -181,6 +196,25 @@ func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error { } +// GetSSHHostPrincipals gets a list of all valid host principals. +func (db *DB) GetSSHHostPrincipals() ([]string, error) { + entries, err := db.List(sshHostPrincipalsTable) + if err != nil { + return nil, err + } + var principals []string + for _, e := range entries { + var data sshHostPrincipalData + if err := json.Unmarshal(e.Value, &data); err != nil { + return nil, err + } + if time.Unix(int64(data.Expiry), 0).After(time.Now()) { + principals = append(principals, string(e.Key)) + } + } + return principals, nil +} + // Shutdown sends a shutdown message to the database. func (db *DB) Shutdown() error { if db.isUp { diff --git a/db/simple.go b/db/simple.go index 7989de44..b0733d8d 100644 --- a/db/simple.go +++ b/db/simple.go @@ -69,6 +69,11 @@ func (s *SimpleDB) StoreSSHCertificate(crt *ssh.Certificate) error { return ErrNotImplemented } +// GetSSHHostPrincipals returns a "NotImplemented" error. +func (s *SimpleDB) GetSSHHostPrincipals() ([]string, error) { + return nil, ErrNotImplemented +} + // Shutdown returns nil func (s *SimpleDB) Shutdown() error { return nil