diff --git a/api/api.go b/api/api.go index 6c4f760d..0284167f 100644 --- a/api/api.go +++ b/api/api.go @@ -256,6 +256,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) r.MethodFunc("POST", "/ssh/config", h.SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) + r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) diff --git a/api/ssh.go b/api/ssh.go index e34174db..ad649e43 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -20,6 +20,7 @@ type SSHAuthority interface { GetSSHRoots() (*authority.SSHKeys, error) GetSSHFederation() (*authority.SSHKeys, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) + CheckSSHHost(principal string) (bool, error) } // SSHSignRequest is the request body of an SSH certificate request. @@ -170,6 +171,32 @@ type SSHConfigResponse struct { HostTemplates []Template `json:"hostTemplates,omitempty"` } +// SSHCheckPrincipalRequest is the request body used to check if a principal +// certificate has been created. Right now it only supported for hosts +// certificates. +type SSHCheckPrincipalRequest struct { + Type string `json:"type"` + Principal string `json:"principal"` +} + +// Validate checks the check principal request. +func (r *SSHCheckPrincipalRequest) Validate() error { + switch { + case r.Type != provisioner.SSHHostCert: + return errors.Errorf("unsupported type %s", r.Type) + case r.Principal == "": + return errors.New("missing or empty principal") + default: + return nil + } +} + +// SSHCheckPrincipalResponse is the response body used to check if a principal +// exists. +type SSHCheckPrincipalResponse struct { + Exists bool `json:"exists"` +} + // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. @@ -320,3 +347,25 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { JSON(w, config) } + +// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. +func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { + var body SSHCheckPrincipalRequest + if err := ReadJSON(r.Body, &body); err != nil { + WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + return + } + if err := body.Validate(); err != nil { + WriteError(w, BadRequest(err)) + return + } + + exists, err := h.Authority.CheckSSHHost(body.Principal) + if err != nil { + WriteError(w, InternalServerError(err)) + return + } + JSON(w, &SSHCheckPrincipalResponse{ + Exists: exists, + }) +} diff --git a/authority/ssh.go b/authority/ssh.go index 741d57cf..1c3f39bb 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/templates" "github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/jose" @@ -263,6 +264,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign } } + if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + return nil, &apiError{ + err: errors.Wrap(err, "signSSH: error storing certificate in db"), + code: http.StatusInternalServerError, + } + } + return cert, nil } @@ -276,13 +284,13 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) } if subject.CertType != ssh.UserCert { return nil, &apiError{ - err: errors.New("signSSHProxy: certificate is not a user certificate"), + err: errors.New("signSSHAddUser: certificate is not a user certificate"), code: http.StatusForbidden, } } if len(subject.ValidPrincipals) != 1 { return nil, &apiError{ - err: errors.New("signSSHProxy: certificate does not have only one principal"), + err: errors.New("signSSHAddUser: certificate does not have only one principal"), code: http.StatusForbidden, } } @@ -295,7 +303,7 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) var serial uint64 if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { return nil, &apiError{ - err: errors.Wrap(err, "signSSHProxy: error reading random number"), + err: errors.Wrap(err, "signSSHAddUser: error reading random number"), code: http.StatusInternalServerError, } } @@ -331,9 +339,36 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) return nil, err } cert.Signature = sig + + if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + return nil, &apiError{ + err: errors.Wrap(err, "signSSHAddUser: error storing certificate in db"), + code: http.StatusInternalServerError, + } + } + return cert, nil } +// CheckSSHHost checks the given principal has been registered before. +func (a *Authority) CheckSSHHost(principal string) (bool, error) { + exists, err := a.db.IsSSHHost(principal) + if err != nil { + if err == db.ErrNotImplemented { + return false, &apiError{ + err: errors.Wrap(err, "checkSSHHost: isSSHHost is not implemented"), + code: http.StatusNotImplemented, + } + } + return false, &apiError{ + err: errors.Wrap(err, "checkSSHHost: error checking if hosts exists"), + code: http.StatusInternalServerError, + } + } + + return exists, nil +} + func (a *Authority) getAddUserPrincipal() (cmd string) { if a.config.SSH.AddUserPrincipal == "" { return SSHAddUserPrincipal diff --git a/ca/client.go b/ca/client.go index 7f34a92c..160bfe52 100644 --- a/ca/client.go +++ b/ca/client.go @@ -21,6 +21,8 @@ import ( "strconv" "strings" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" @@ -585,6 +587,31 @@ func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, e return &config, nil } +// SSHCheckHost performs the POST /ssh/check-host request to the CA with the +// given principal. +func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) { + body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ + Type: provisioner.SSHHostCert, + Principal: principal, + }) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"}) + resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var check api.SSHCheckPrincipalResponse + if err := readJSON(resp.Body, &check); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &check, nil +} + // RootFingerprint is a helper method that returns the current root fingerprint. // It does an health connection and gets the fingerprint from the TLS verified // chains. diff --git a/db/db.go b/db/db.go index 9cf7031e..17e5c209 100644 --- a/db/db.go +++ b/db/db.go @@ -3,17 +3,23 @@ package db import ( "crypto/x509" "encoding/json" + "strconv" + "strings" "time" "github.com/pkg/errors" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" + "golang.org/x/crypto/ssh" ) var ( certsTable = []byte("x509_certs") revokedCertsTable = []byte("revoked_x509_certs") usedOTTTable = []byte("used_ott") + sshCertsTable = []byte("ssh_certs") + sshHostsTable = []byte("ssh_hosts") + sshUsersTable = []byte("ssh_users") ) // ErrAlreadyExists can be returned if the DB attempts to set a key that has @@ -34,6 +40,8 @@ type AuthDB interface { Revoke(rci *RevokedCertificateInfo) error StoreCertificate(crt *x509.Certificate) error UseToken(id, tok string) (bool, error) + IsSSHHost(name string) (bool, error) + StoreSSHCertificate(crt *ssh.Certificate) error Shutdown() error } @@ -55,7 +63,10 @@ func New(c *Config) (AuthDB, error) { return nil, errors.Wrapf(err, "Error opening database of Type %s with source %s", c.Type, c.DataSource) } - tables := [][]byte{revokedCertsTable, certsTable, usedOTTTable} + tables := [][]byte{ + revokedCertsTable, certsTable, usedOTTTable, + sshCertsTable, sshHostsTable, sshUsersTable, + } for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", @@ -138,6 +149,38 @@ func (db *DB) UseToken(id, tok string) (bool, error) { return swapped, nil } +// IsSSHHost returns if a principal is present in the ssh hosts table. +func (db *DB) IsSSHHost(principal string) (bool, error) { + if _, err := db.Get(sshHostsTable, []byte(strings.ToLower(principal))); err != nil { + if database.IsErrNotFound(err) { + return false, nil + } + return false, errors.Wrap(err, "database Get error") + } + return true, nil +} + +// StoreSSHCertificate stores an SSH certificate. +func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error { + var table []byte + serial := strconv.FormatUint(crt.Serial, 10) + tx := new(database.Tx) + tx.Set(sshCertsTable, []byte(serial), crt.Marshal()) + if crt.CertType == ssh.HostCert { + table = sshHostsTable + } else { + table = sshUsersTable + } + for _, p := range crt.ValidPrincipals { + tx.Set(table, []byte(strings.ToLower(p)), []byte(serial)) + } + if err := db.Update(tx); err != nil { + return errors.Wrap(err, "database Update error") + } + return nil + +} + // Shutdown sends a shutdown message to the database. func (db *DB) Shutdown() error { if db.isUp { diff --git a/db/simple.go b/db/simple.go index 30c2b124..7989de44 100644 --- a/db/simple.go +++ b/db/simple.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/nosql/database" + "golang.org/x/crypto/ssh" ) // ErrNotImplemented is an error returned when an operation is Not Implemented. @@ -58,6 +59,16 @@ func (s *SimpleDB) UseToken(id, tok string) (bool, error) { return true, nil } +// IsSSHHost returns a "NotImplemented" error. +func (s *SimpleDB) IsSSHHost(principal string) (bool, error) { + return false, ErrNotImplemented +} + +// StoreSSHCertificate returns a "NotImplemented" error. +func (s *SimpleDB) StoreSSHCertificate(crt *ssh.Certificate) error { + return ErrNotImplemented +} + // Shutdown returns nil func (s *SimpleDB) Shutdown() error { return nil