From 51c6d6a4f2fc3c3fe4f69abede61c1139299f874 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 31 May 2022 16:12:25 -0700 Subject: [PATCH] Backport some changes from v0.20.0 to a new branch --- api/ssh.go | 2 + api/sshRekey.go | 2 + api/sshRenew.go | 2 + authority/authorize.go | 44 ++++++++++++------- authority/authorize_test.go | 40 ++++++++--------- authority/provisioner/method.go | 13 ++++++ authority/ssh.go | 78 +++++++++++++++++++++++++++++---- authority/tls.go | 29 +++++++----- db/db.go | 9 +++- db/simple.go | 2 +- 10 files changed, 163 insertions(+), 58 deletions(-) diff --git a/api/ssh.go b/api/ssh.go index 3b0de7c1..3bbd849a 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -288,6 +288,8 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) diff --git a/api/sshRekey.go b/api/sshRekey.go index 92278950..6d347945 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -59,6 +59,8 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) diff --git a/api/sshRenew.go b/api/sshRenew.go index 78d16fa6..5d659567 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -51,6 +51,8 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) diff --git a/authority/authorize.go b/authority/authorize.go index 7f9f456c..1e3eaad5 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "fmt" "net/http" "net/url" "strconv" @@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool { return m } -// authorizeToken parses the token and returns the provisioner used to generate -// the token. This method enforces the One-Time use policy (tokens can only be -// used once). -func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { - // Validate payload +// getProvisionerFromToken extracts a provisioner from the given token without +// doing any token validation. +func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) { tok, err := jose.ParseSigned(token) if err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token") + return nil, nil, fmt.Errorf("error parsing token: %w", err) } // Get claims w/out verification. We need to look up the provisioner @@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // before we can look up the provisioner. var claims Claims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") + return nil, nil, fmt.Errorf("error unmarshaling token: %w", err) + } + + // This method will also validate the audiences for JWK provisioners. + p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) + if !ok { + return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) + } + + return p, &claims, nil +} + +// authorizeToken parses the token and returns the provisioner used to generate +// the token. This method enforces the One-Time use policy (tokens can only be +// used once). +func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { + p, claims, err := a.getProvisionerFromToken(token) + if err != nil { + return nil, errs.UnauthorizedErr(err) } // TODO: use new persistence layer abstraction. @@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { - return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority") + return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority") } } - // This method will also validate the audiences for JWK provisioners. - p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) - if !ok { - return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+ - "not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) - } - // Store the token to protect against reuse unless it's skipped. // If we cannot get a token id from the provisioner, just hash the token. if !SkipTokenReuseFromContext(ctx) { @@ -189,10 +199,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error { ok, err := a.db.UseToken(reuseKey, token) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, - "authority.authorizeToken: failed when attempting to store token") + "failed when attempting to store token") } if !ok { - return errs.Unauthorized("authority.authorizeToken: token already used") + return errs.Unauthorized("token already used") } } return nil diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 087318be..13247311 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeToken: error parsing token"), + err: errors.New("error parsing token"), code: http.StatusUnauthorized, } }, @@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: raw, - err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"), + err: errors.New("token issued before the bootstrap of certificate authority"), code: http.StatusUnauthorized, } }, @@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: raw, - err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"), + err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"), code: http.StatusUnauthorized, } }, @@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: failed when attempting to store token: force"), + err: errors.New("failed when attempting to store token: force"), code: http.StatusInternalServerError, } }, @@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: context.Background(), - err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), - err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), - err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, @@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, @@ -988,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1082,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1190,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1282,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index f5cd5221..01dda2ed 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method { m, _ := ctx.Value(methodKey{}).(Method) return m } + +type tokenKey struct{} + +// NewContextWithToken creates a new context with the given token. +func NewContextWithToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, tokenKey{}, token) +} + +// TokenFromContext returns the token stored in the given context. +func TokenFromContext(ctx context.Context) (string, bool) { + token, ok := ctx.Value(tokenKey{}).(string) + return token, ok +} diff --git a/authority/ssh.go b/authority/ssh.go index 0521ab58..0bef73ae 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -161,6 +161,12 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Set backdate with the configured value opts.Backdate = a.config.AuthorityConfig.Backdate.Duration + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + for _, op := range signOpts { switch o := op.(type) { // add options to NewCertificate @@ -276,7 +282,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi } } - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeSSHCertificate(prov, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db") } @@ -298,6 +304,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, err } + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + backdate := a.config.AuthorityConfig.Backdate.Duration duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second now := time.Now() @@ -340,7 +352,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") } @@ -351,6 +363,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var validators []provisioner.SSHCertValidator + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + for _, op := range signOpts { switch o := op.(type) { // validate the ssh.Certificate @@ -419,21 +437,59 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } } - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") } return cert, nil } -func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error { +func (a *Authority) storeSSHCertificate(prov provisioner.Interface, cert *ssh.Certificate) error { type sshCertificateStorer interface { - StoreSSHCertificate(crt *ssh.Certificate) error + StoreSSHCertificate(provisioner.Interface, *ssh.Certificate) error } - if s, ok := a.adminDB.(sshCertificateStorer); ok { + + // Store certificate in admindb or linkedca + switch s := a.adminDB.(type) { + case sshCertificateStorer: + return s.StoreSSHCertificate(prov, cert) + case db.CertificateStorer: return s.StoreSSHCertificate(cert) } - return a.db.StoreSSHCertificate(cert) + + // Store certificate in localdb + switch s := a.db.(type) { + case sshCertificateStorer: + return s.StoreSSHCertificate(prov, cert) + case db.CertificateStorer: + return s.StoreSSHCertificate(cert) + default: + return nil + } +} + +func (a *Authority) storeRenewedSSHCertificate(prov provisioner.Interface, parent, cert *ssh.Certificate) error { + type sshRenewerCertificateStorer interface { + StoreRenewedSSHCertificate(p provisioner.Interface, parent, cert *ssh.Certificate) error + } + + // Store certificate in admindb or linkedca + switch s := a.adminDB.(type) { + case sshRenewerCertificateStorer: + return s.StoreRenewedSSHCertificate(prov, parent, cert) + case db.CertificateStorer: + return s.StoreSSHCertificate(cert) + } + + // Store certificate in localdb + switch s := a.db.(type) { + case sshRenewerCertificateStorer: + return s.StoreRenewedSSHCertificate(prov, parent, cert) + case db.CertificateStorer: + return s.StoreSSHCertificate(cert) + default: + return nil + } } // IsValidForAddUser checks if a user provisioner certificate can be issued to @@ -479,6 +535,12 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number") } + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + signer := a.sshCAUserCertSignKey principal := subject.ValidPrincipals[0] addUserPrincipal := a.getAddUserPrincipal() @@ -511,7 +573,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje } cert.Signature = sig - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, subject, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") } diff --git a/authority/tls.go b/authority/tls.go index d23b0da7..4c29ca15 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -365,28 +365,31 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 // `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(*x509.Certificate) error`. func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error { - type linkedChainStorer interface { + type certificateChainStorer interface { StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error } - type certificateChainStorer interface { + type certificateChainSimpleStorer interface { StoreCertificateChain(...*x509.Certificate) error } + // Store certificate in linkedca switch s := a.adminDB.(type) { - case linkedChainStorer: - return s.StoreCertificateChain(prov, fullchain...) case certificateChainStorer: + return s.StoreCertificateChain(prov, fullchain...) + case certificateChainSimpleStorer: return s.StoreCertificateChain(fullchain...) } // Store certificate in local db switch s := a.db.(type) { - case linkedChainStorer: - return s.StoreCertificateChain(prov, fullchain...) case certificateChainStorer: + return s.StoreCertificateChain(prov, fullchain...) + case certificateChainSimpleStorer: return s.StoreCertificateChain(fullchain...) + case db.CertificateStorer: + return s.StoreCertificate(fullchain[0]) default: - return a.db.StoreCertificate(fullchain[0]) + return nil } } @@ -398,15 +401,21 @@ func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain type renewedCertificateChainStorer interface { StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error } + // Store certificate in linkedca if s, ok := a.adminDB.(renewedCertificateChainStorer); ok { return s.StoreRenewedCertificate(oldCert, fullchain...) } + // Store certificate in local db - if s, ok := a.db.(renewedCertificateChainStorer); ok { + switch s := a.db.(type) { + case renewedCertificateChainStorer: return s.StoreRenewedCertificate(oldCert, fullchain...) + case db.CertificateStorer: + return s.StoreCertificate(fullchain[0]) + default: + return nil } - return a.db.StoreCertificate(fullchain[0]) } // RevokeOptions are the options for the Revoke API. @@ -551,7 +560,7 @@ func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateIn }); ok { return lca.RevokeSSH(crt, rci) } - return a.db.Revoke(rci) + return a.db.RevokeSSH(rci) } // GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server. diff --git a/db/db.go b/db/db.go index eccaf801..8cd1db0f 100644 --- a/db/db.go +++ b/db/db.go @@ -50,14 +50,19 @@ type AuthDB interface { Revoke(rci *RevokedCertificateInfo) error RevokeSSH(rci *RevokedCertificateInfo) error GetCertificate(serialNumber string) (*x509.Certificate, error) - StoreCertificate(crt *x509.Certificate) error UseToken(id, tok string) (bool, error) IsSSHHost(name string) (bool, error) - StoreSSHCertificate(crt *ssh.Certificate) error GetSSHHostPrincipals() ([]string, error) Shutdown() error } +// CertificateStorer is an extension of AuthDB that allows to store +// certificates. +type CertificateStorer interface { + StoreCertificate(crt *x509.Certificate) error + StoreSSHCertificate(crt *ssh.Certificate) error +} + // DB is a wrapper over the nosql.DB interface. type DB struct { nosql.DB diff --git a/db/simple.go b/db/simple.go index 0e5426ec..a7e38de9 100644 --- a/db/simple.go +++ b/db/simple.go @@ -20,7 +20,7 @@ type SimpleDB struct { usedTokens *sync.Map } -func newSimpleDB(c *Config) (AuthDB, error) { +func newSimpleDB(c *Config) (*SimpleDB, error) { db := &SimpleDB{} db.usedTokens = new(sync.Map) return db, nil