Add SSH getHosts api

This commit is contained in:
max furman 2019-10-25 13:47:49 -07:00
parent ded8087042
commit 5616386eed
6 changed files with 98 additions and 12 deletions

View File

@ -257,6 +257,7 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("POST", "/ssh/config", h.SSHConfig) r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
r.MethodFunc("POST", "/ssh/get-hosts", h.SSHGetHosts)
// For compatibility with old code: // For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew) r.MethodFunc("POST", "/re-sign", h.Renew)

View File

@ -21,6 +21,7 @@ type SSHAuthority interface {
GetSSHFederation() (*authority.SSHKeys, error) GetSSHFederation() (*authority.SSHKeys, error)
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
CheckSSHHost(principal string) (bool, error) CheckSSHHost(principal string) (bool, error)
GetSSHHosts() ([]string, error)
} }
// SSHSignRequest is the request body of an SSH certificate request. // SSHSignRequest is the request body of an SSH certificate request.
@ -66,6 +67,11 @@ type SSHCertificate struct {
*ssh.Certificate `json:"omitempty"` *ssh.Certificate `json:"omitempty"`
} }
// SSHGetHostsResponse
type SSHGetHostsResponse struct {
Hosts []string `json:"hosts"`
}
// MarshalJSON implements the json.Marshaler interface. Returns a quoted, // MarshalJSON implements the json.Marshaler interface. Returns a quoted,
// base64 encoded, openssh wire format version of the certificate. // base64 encoded, openssh wire format version of the certificate.
func (c SSHCertificate) MarshalJSON() ([]byte, error) { func (c SSHCertificate) MarshalJSON() ([]byte, error) {
@ -369,3 +375,15 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
Exists: exists, Exists: exists,
}) })
} }
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts()
if err != nil {
WriteError(w, InternalServerError(err))
return
}
JSON(w, &SSHGetHostsResponse{
Hosts: hosts,
})
}

View File

@ -369,6 +369,16 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
return exists, nil return exists, nil
} }
// GetSSHHosts returns a list of valid host principals.
func (a *Authority) GetSSHHosts() ([]string, error) {
ps, err := a.db.GetSSHHostPrincipals()
if err != nil {
return nil, err
}
return ps, nil
}
func (a *Authority) getAddUserPrincipal() (cmd string) { func (a *Authority) getAddUserPrincipal() (cmd string) {
if a.config.SSH.AddUserPrincipal == "" { if a.config.SSH.AddUserPrincipal == "" {
return SSHAddUserPrincipal return SSHAddUserPrincipal

View File

@ -611,6 +611,24 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
return &check, nil return &check, nil
} }
// SSHGetHostPrincipals performs the POST /ssh/check-host request to the CA with the
// given principal.
func (c *Client) SSHGetHostPrincipals() (*api.SSHGetHostsResponse, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"})
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var hosts api.SSHGetHostsResponse
if err := readJSON(resp.Body, &hosts); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &hosts, nil
}
// RootFingerprint is a helper method that returns the current root fingerprint. // RootFingerprint is a helper method that returns the current root fingerprint.
// It does an health connection and gets the fingerprint from the TLS verified // It does an health connection and gets the fingerprint from the TLS verified
// chains. // chains.

View File

@ -14,12 +14,13 @@ import (
) )
var ( var (
certsTable = []byte("x509_certs") certsTable = []byte("x509_certs")
revokedCertsTable = []byte("revoked_x509_certs") revokedCertsTable = []byte("revoked_x509_certs")
usedOTTTable = []byte("used_ott") usedOTTTable = []byte("used_ott")
sshCertsTable = []byte("ssh_certs") sshCertsTable = []byte("ssh_certs")
sshHostsTable = []byte("ssh_hosts") sshHostsTable = []byte("ssh_hosts")
sshUsersTable = []byte("ssh_users") sshUsersTable = []byte("ssh_users")
sshHostPrincipalsTable = []byte("ssh_host_principals")
) )
// ErrAlreadyExists can be returned if the DB attempts to set a key that has // ErrAlreadyExists can be returned if the DB attempts to set a key that has
@ -42,6 +43,7 @@ type AuthDB interface {
UseToken(id, tok string) (bool, error) UseToken(id, tok string) (bool, error)
IsSSHHost(name string) (bool, error) IsSSHHost(name string) (bool, error)
StoreSSHCertificate(crt *ssh.Certificate) error StoreSSHCertificate(crt *ssh.Certificate) error
GetSSHHostPrincipals() ([]string, error)
Shutdown() error Shutdown() error
} }
@ -160,19 +162,32 @@ func (db *DB) IsSSHHost(principal string) (bool, error) {
return true, nil return true, nil
} }
type sshHostPrincipalData struct {
Serial string
Expiry uint64
}
// StoreSSHCertificate stores an SSH certificate. // StoreSSHCertificate stores an SSH certificate.
func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error { func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
var table []byte
serial := strconv.FormatUint(crt.Serial, 10) serial := strconv.FormatUint(crt.Serial, 10)
tx := new(database.Tx) tx := new(database.Tx)
tx.Set(sshCertsTable, []byte(serial), crt.Marshal()) tx.Set(sshCertsTable, []byte(serial), crt.Marshal())
if crt.CertType == ssh.HostCert { if crt.CertType == ssh.HostCert {
table = sshHostsTable for _, p := range crt.ValidPrincipals {
hostPrincipalData, err := json.Marshal(sshHostPrincipalData{
Serial: serial,
Expiry: crt.ValidBefore,
})
if err != nil {
return err
}
tx.Set(sshHostsTable, []byte(strings.ToLower(p)), []byte(serial))
tx.Set(sshHostPrincipalsTable, []byte(strings.ToLower(p)), hostPrincipalData)
}
} else { } else {
table = sshUsersTable for _, p := range crt.ValidPrincipals {
} tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial))
for _, p := range crt.ValidPrincipals { }
tx.Set(table, []byte(strings.ToLower(p)), []byte(serial))
} }
if err := db.Update(tx); err != nil { if err := db.Update(tx); err != nil {
return errors.Wrap(err, "database Update error") return errors.Wrap(err, "database Update error")
@ -181,6 +196,25 @@ func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
} }
// GetSSHHostPrincipals gets a list of all valid host principals.
func (db *DB) GetSSHHostPrincipals() ([]string, error) {
entries, err := db.List(sshHostPrincipalsTable)
if err != nil {
return nil, err
}
var principals []string
for _, e := range entries {
var data sshHostPrincipalData
if err := json.Unmarshal(e.Value, &data); err != nil {
return nil, err
}
if time.Unix(int64(data.Expiry), 0).After(time.Now()) {
principals = append(principals, string(e.Key))
}
}
return principals, nil
}
// Shutdown sends a shutdown message to the database. // Shutdown sends a shutdown message to the database.
func (db *DB) Shutdown() error { func (db *DB) Shutdown() error {
if db.isUp { if db.isUp {

View File

@ -69,6 +69,11 @@ func (s *SimpleDB) StoreSSHCertificate(crt *ssh.Certificate) error {
return ErrNotImplemented return ErrNotImplemented
} }
// GetSSHHostPrincipals returns a "NotImplemented" error.
func (s *SimpleDB) GetSSHHostPrincipals() ([]string, error) {
return nil, ErrNotImplemented
}
// Shutdown returns nil // Shutdown returns nil
func (s *SimpleDB) Shutdown() error { func (s *SimpleDB) Shutdown() error {
return nil return nil