diff --git a/authority/authority.go b/authority/authority.go index 5a0cf1ab..4ff32150 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/x509util" ) @@ -16,18 +17,15 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { - config *Config - rootX509Certs []*x509.Certificate - intermediateIdentity *x509util.Identity - validateOnce bool - certificates *sync.Map - ottMap *sync.Map - startTime time.Time - provisionerIDIndex *sync.Map - encryptedKeyIndex *sync.Map - provisionerKeySetIndex *sync.Map - sortedProvisioners provisionerSlice - audiences []string + config *Config + rootX509Certs []*x509.Certificate + intermediateIdentity *x509util.Identity + validateOnce bool + certificates *sync.Map + ottMap *sync.Map + startTime time.Time + provisioners *provisioner.Collection + audiences []string // Do not re-initialize initOnce bool } @@ -39,15 +37,6 @@ func New(config *Config) (*Authority, error) { return nil, err } - // Get sorted provisioners - var sorted provisionerSlice - if config.AuthorityConfig != nil { - sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners) - if err != nil { - return nil, err - } - } - // Define audiences: legacy + possible urls without the ports. // The CA might have proxies in front so we cannot rely on the port. audiences := []string{legacyAuthority} @@ -56,14 +45,11 @@ func New(config *Config) (*Authority, error) { } var a = &Authority{ - config: config, - certificates: new(sync.Map), - ottMap: new(sync.Map), - provisionerIDIndex: new(sync.Map), - encryptedKeyIndex: new(sync.Map), - provisionerKeySetIndex: new(sync.Map), - sortedProvisioners: sorted, - audiences: audiences, + config: config, + certificates: new(sync.Map), + ottMap: new(sync.Map), + provisioners: provisioner.NewCollection(audiences), + audiences: audiences, } if err := a.init(); err != nil { return nil, err @@ -120,10 +106,10 @@ func (a *Authority) init() error { } } + // Store all the provisioners for _, p := range a.config.AuthorityConfig.Provisioners { - a.provisionerIDIndex.Store(p.ID(), p) - if len(p.EncryptedKey) != 0 { - a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey) + if err := a.provisioners.Store(p); err != nil { + return err } } diff --git a/authority/provisioners.go b/authority/provisioners.go index 85713b7e..dcea9121 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1,115 +1,25 @@ package authority import ( - "crypto/sha1" - "encoding/binary" - "encoding/hex" - "encoding/json" - "fmt" - "math" "net/http" - "sort" - "strings" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" ) -// DefaultProvisionersLimit is the default limit for listing provisioners. -const DefaultProvisionersLimit = 20 - -// DefaultProvisionersMax is the maximum limit for listing provisioners. -const DefaultProvisionersMax = 100 - // GetEncryptedKey returns the JWE key corresponding to the given kid argument. func (a *Authority) GetEncryptedKey(kid string) (string, error) { - val, ok := a.encryptedKeyIndex.Load(kid) + key, ok := a.provisioners.LoadEncryptedKey(kid) if !ok { return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid), http.StatusNotFound, context{}} } - - key, ok := val.(string) - if !ok { - return "", &apiError{errors.Errorf("stored value is not a string"), - http.StatusInternalServerError, context{}} - } return key, nil } // GetProvisioners returns a map listing each provisioner and the JWK Key Set // with their public keys. -func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) { - provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit) +func (a *Authority) GetProvisioners(cursor string, limit int) ([]*provisioner.Provisioner, string, error) { + provisioners, nextCursor := a.provisioners.Find(cursor, limit) return provisioners, nextCursor, nil } - -type uidProvisioner struct { - provisioner *Provisioner - uid string -} - -func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) { - if len(provisioners) > math.MaxInt32 { - return nil, errors.New("too many provisioners") - } - - var slice provisionerSlice - bi := make([]byte, 4) - for i, p := range provisioners { - sum, err := provisionerSum(p) - if err != nil { - return nil, err - } - // Use the first 4 bytes (32bit) of the sum to insert the order - // Using big endian format to get the strings sorted: - // 0x00000000, 0x00000001, 0x00000002, ... - binary.BigEndian.PutUint32(bi, uint32(i)) - sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] - bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0 - slice = append(slice, uidProvisioner{ - provisioner: p, - uid: hex.EncodeToString(sum), - }) - } - sort.Sort(slice) - return slice, nil -} - -type provisionerSlice []uidProvisioner - -func (p provisionerSlice) Len() int { return len(p) } -func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } -func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } - -func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) { - switch { - case limit <= 0: - limit = DefaultProvisionersLimit - case limit > DefaultProvisionersMax: - limit = DefaultProvisionersMax - } - - n := len(p) - cursor = fmt.Sprintf("%040s", cursor) - i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor }) - - var slice []*Provisioner - for ; i < n && len(slice) < limit; i++ { - slice = append(slice, p[i].provisioner) - } - if i < n { - return slice, strings.TrimLeft(p[i].uid, "0") - } - return slice, "" -} - -// provisionerSum returns the SHA1 of the json representation of the -// provisioner. From this we will create the unique and sorted id. -func provisionerSum(p *Provisioner) ([]byte, error) { - b, err := json.Marshal(p.Key) - if err != nil { - return nil, errors.Wrap(err, "error marshalling provisioner") - } - sum := sha1.Sum(b) - return sum[:], nil -}