package provisioner

import (
	"encoding/json"
	"math/rand"
	"net/http"
	"regexp"
	"strconv"
	"sync"
	"time"

	"github.com/pkg/errors"
	"go.step.sm/crypto/jose"
)

const (
	// defaultCacheAge is the cache lifetime used when the JWKs response does
	// not carry a usable max-age directive.
	defaultCacheAge = 12 * time.Hour
	// defaultCacheJitter bounds the random offset applied to reload delays
	// for long (or zero) cache lifetimes.
	defaultCacheJitter = 1 * time.Hour
)

// maxAgeRegex extracts the max-age directive from a Cache-Control header.
// Directive names are compared case-insensitively (RFC 9111, section 5.2),
// hence the (?i) flag.
var maxAgeRegex = regexp.MustCompile("(?i)max-age=([0-9]+)")

// keyStore caches the JSON Web Key Set served at uri. It refreshes the set
// with a background timer, and Get forces a synchronous reload once the
// cached set has expired. keySet, expiry, jitter and closed are guarded by
// the embedded RWMutex; the timer is only re-armed while holding it.
type keyStore struct {
	sync.RWMutex
	uri    string
	keySet jose.JSONWebKeySet
	timer  *time.Timer
	expiry time.Time
	jitter time.Duration

	// closed is set by Close; once true, reload no longer re-arms the timer.
	closed bool
}

// newKeyStore downloads the key set at uri and returns a keyStore that keeps
// reloading it in the background until Close is called. It returns the
// download error if the initial fetch fails.
func newKeyStore(uri string) (*keyStore, error) {
	keys, age, err := getKeysFromJWKsURI(uri)
	if err != nil {
		return nil, err
	}
	ks := &keyStore{
		uri:    uri,
		keySet: keys,
		expiry: getExpirationTime(age),
		jitter: getCacheJitter(age),
	}
	next := ks.nextReloadDuration(age)
	// Arm the timer under the lock: the delay can be (close to) zero, and
	// reload only touches ks.timer while holding the same lock, so it cannot
	// observe the field before it has been assigned.
	ks.Lock()
	ks.timer = time.AfterFunc(next, ks.reload)
	ks.Unlock()
	return ks, nil
}

// Close stops the background reload. Calling it more than once is harmless.
// Get keeps working afterwards and may still reload an expired key set
// synchronously, but such reloads no longer re-arm the timer.
func (ks *keyStore) Close() {
	ks.Lock()
	ks.closed = true
	ks.timer.Stop()
	ks.Unlock()
}

// Get returns the keys of the cached set that match kid, as reported by
// jose.JSONWebKeySet.Key. If the cached set has expired it is reloaded
// synchronously first; when that reload fails the previously cached keys are
// used (and, since expiry is unchanged, the next Get tries again).
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
	ks.RLock()
	// Force reload if expiration has passed. The read lock has to be released
	// first because reload takes the write lock.
	if time.Now().After(ks.expiry) {
		ks.RUnlock()
		ks.reload()
		ks.RLock()
	}
	keys = ks.keySet.Key(kid)
	ks.RUnlock()
	return keys
}

// reload downloads the key set again and re-arms the background timer
// (unless the store has been closed). On success the cached keys, expiry and
// jitter are replaced; on failure the cached state is kept and a retry is
// scheduled using half of the current jitter as the base delay.
func (ks *keyStore) reload() {
	// Fetch without holding the lock so readers are not blocked on network
	// I/O. ks.uri is never modified after construction.
	keys, age, err := getKeysFromJWKsURI(ks.uri)

	ks.Lock()
	defer ks.Unlock()

	var next time.Duration
	if err != nil {
		// ks.jitter must be read under the lock: reload can run concurrently
		// from the timer goroutine and from Get.
		next = ks.nextReloadDuration(ks.jitter / 2)
	} else {
		ks.keySet = keys
		ks.expiry = getExpirationTime(age)
		ks.jitter = getCacheJitter(age)
		next = ks.nextReloadDuration(age)
	}
	if !ks.closed {
		ks.timer.Reset(next)
	}
}

// nextReloadDuration returns the delay until the next background reload:
// age minus a random offset in [0, ks.jitter), folded to a non-negative value
// with abs. When age is 0 (e.g. "max-age=0") the jitter is defaultCacheJitter,
// so the timer fires at a random point within the next hour; in that case Get
// also forces a reload on every call because the key set is already expired.
// It reads ks.jitter without locking, so callers must hold the write lock or
// otherwise have exclusive access to the keyStore.
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { n := rand.Int63n(int64(ks.jitter)) age -= time.Duration(n) return abs(age) } func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) { var keys jose.JSONWebKeySet resp, err := http.Get(uri) if err != nil { return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri) } defer resp.Body.Close() if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { return keys, 0, errors.Wrapf(err, "error reading %s", uri) } return keys, getCacheAge(resp.Header.Get("cache-control")), nil } func getCacheAge(cacheControl string) time.Duration { age := defaultCacheAge if len(cacheControl) > 0 { match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1) if len(match) > 0 { if len(match[0]) == 2 { maxAge := match[0][1] maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64) if err != nil { return defaultCacheAge } age = time.Duration(maxAgeInt) * time.Second } } } return age } func getCacheJitter(age time.Duration) time.Duration { switch { case age > time.Hour: return defaultCacheJitter case age == 0: // Avoids a 0 jitter. The duration is not important as it will rotate // automatically on each Get request. return defaultCacheJitter default: return age / 3 } } func getExpirationTime(age time.Duration) time.Time { return time.Now().Truncate(time.Second).Add(age) } // abs returns the absolute value of n. func abs(n time.Duration) time.Duration { if n < 0 { return -n } return n }