diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 5ebb737c..c4390f03 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -47,7 +47,7 @@ type GCP struct { ServiceAccounts []string `json:"serviceAccounts"` Claims *Claims `json:"claims,omitempty"` claimer *Claimer - certStore *keyStore + keyStore *keyStore } // GetID returns the provisioner unique identifier. The name should uniquely @@ -103,8 +103,8 @@ func (p *GCP) Init(config Config) error { if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { return err } - // Initialize certificate store - p.certStore, err = newCertificateStore("https://www.googleapis.com/oauth2/v1/certs") + // Initialize key store + p.keyStore, err = newKeyStore("https://www.googleapis.com/oauth2/v3/certs") if err != nil { return err } @@ -185,15 +185,19 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { if len(jwt.Headers) == 0 { return nil, errors.New("error parsing token: header is missing") } - kid := jwt.Headers[0].KeyID - cert := p.certStore.GetCertificate(kid) - if cert == nil { - return nil, errors.Errorf("failed to validate payload: cannot find certificate for kid %s", kid) - } + var found bool var claims gcpPayload - if err = jwt.Claims(cert.PublicKey, &claims); err != nil { - return nil, errors.Wrap(err, "error parsing claims") + kid := jwt.Headers[0].KeyID + keys := p.keyStore.Get(kid) + for _, key := range keys { + if err := jwt.Claims(key, &claims); err == nil { + found = true + break + } + } + if !found { + return nil, errors.Errorf("failed to validate payload: cannot find certificate for kid %s", kid) } // According to "rfc7519 JSON Web Token" acceptable skew should be no diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index 89e50df2..2f11114a 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -1,9 +1,7 @@ package provisioner import ( - "crypto/x509" "encoding/json" - "encoding/pem" "math/rand" "net/http" "regexp" @@ -22,32 +20,13 @@ const ( var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)") -type oauth2Certificate struct { - ID string - Certificate *x509.Certificate -} - -type oauth2CertificateSet struct { - Certificates []oauth2Certificate -} - -func (s oauth2CertificateSet) Get(id string) *x509.Certificate { - for _, c := range s.Certificates { - if c.ID == id { - return c.Certificate - } - } - return nil -} - type keyStore struct { sync.RWMutex - uri string - keySet jose.JSONWebKeySet - certSet oauth2CertificateSet - timer *time.Timer - expiry time.Time - jitter time.Duration + uri string + keySet jose.JSONWebKeySet + timer *time.Timer + expiry time.Time + jitter time.Duration } func newKeyStore(uri string) (*keyStore, error) { @@ -66,22 +45,6 @@ func newKeyStore(uri string) (*keyStore, error) { return ks, nil } -func newCertificateStore(uri string) (*keyStore, error) { - certs, age, err := getOauth2Certificates(uri) - if err != nil { - return nil, err - } - ks := &keyStore{ - uri: uri, - certSet: certs, - expiry: getExpirationTime(age), - jitter: getCacheJitter(age), - } - next := ks.nextReloadDuration(age) - ks.timer = time.AfterFunc(next, ks.reloadCertificates) - return ks, nil -} - func (ks *keyStore) Close() { ks.timer.Stop() } @@ -99,19 +62,6 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { return } -func (ks *keyStore) GetCertificate(kid string) (cert *x509.Certificate) { - ks.RLock() - // Force reload if expiration has passed - if time.Now().After(ks.expiry) { - ks.RUnlock() - ks.reloadCertificates() - ks.RLock() - } - cert = ks.certSet.Get(kid) - ks.RUnlock() - return -} - func (ks *keyStore) reload() { var next time.Duration keys, age, err := getKeysFromJWKsURI(ks.uri) @@ -131,25 +81,6 @@ func (ks *keyStore) reload() { ks.Unlock() } -func (ks *keyStore) reloadCertificates() { - var next time.Duration - certs, age, err := getOauth2Certificates(ks.uri) - if err != nil { - next = ks.nextReloadDuration(ks.jitter / 2) - } else { - ks.Lock() - ks.certSet = certs - ks.expiry = getExpirationTime(age) - ks.jitter = getCacheJitter(age) - next = ks.nextReloadDuration(age) - ks.Unlock() - } - - ks.Lock() - ks.timer.Reset(next) - ks.Unlock() -} - func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { n := rand.Int63n(int64(ks.jitter)) age -= time.Duration(n) @@ -172,34 +103,6 @@ func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) { return keys, getCacheAge(resp.Header.Get("cache-control")), nil } -func getOauth2Certificates(uri string) (oauth2CertificateSet, time.Duration, error) { - var certs oauth2CertificateSet - resp, err := http.Get(uri) - if err != nil { - return certs, 0, errors.Wrapf(err, "failed to connect to %s", uri) - } - defer resp.Body.Close() - m := make(map[string]string) - if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { - return certs, 0, errors.Wrapf(err, "error reading %s", uri) - } - for k, v := range m { - block, _ := pem.Decode([]byte(v)) - if block == nil || block.Type != "CERTIFICATE" { - return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri) - } - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri) - } - certs.Certificates = append(certs.Certificates, oauth2Certificate{ - ID: k, - Certificate: cert, - }) - } - return certs, getCacheAge(resp.Header.Get("cache-control")), nil -} - func getCacheAge(cacheControl string) time.Duration { age := defaultCacheAge if len(cacheControl) > 0 {