diff --git a/authority/authority.go b/authority/authority.go index f9ec6fbf..790ab6ad 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -28,6 +28,7 @@ type Authority struct { provisionerIDIndex *sync.Map encryptedKeyIndex *sync.Map provisionerKeySetIndex *sync.Map + sortedProvisioners provisionerSlice audiences []string // Do not re-initialize initOnce bool @@ -35,9 +36,31 @@ type Authority struct { // New creates and initiates a new Authority type. func New(config *Config) (*Authority, error) { - if err := config.Validate(); err != nil { + err := config.Validate() + if err != nil { return nil, err } + + // Get sorted provisioners + var sorted provisionerSlice + if config.AuthorityConfig != nil { + sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners) + } + + // Define audiences: legacy + possible urls + _, port, err := net.SplitHostPort(config.Address) + if err != nil { + return nil, errors.Wrapf(err, "error parsing %s", config.Address) + } + audiences := []string{legacyAuthority} + for _, name := range config.DNSNames { + if port == "443" { + audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name)) + } + audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port)) + + } + var a = &Authority{ config: config, certificates: new(sync.Map), @@ -45,6 +68,8 @@ func New(config *Config) (*Authority, error) { provisionerIDIndex: new(sync.Map), encryptedKeyIndex: new(sync.Map), provisionerKeySetIndex: new(sync.Map), + sortedProvisioners: sorted, + audiences: audiences, } if err := a.init(); err != nil { return nil, err @@ -70,21 +95,6 @@ func (a *Authority) init() error { sum := sha256.Sum256(a.rootX509Crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) - // Define audiences: legacy + possible urls - _, port, err := net.SplitHostPort(a.config.Address) - if err != nil { - return errors.Wrapf(err, "error parsing %s", a.config.Address) - } - audiences := []string{legacyAuthority} - for _, name := range a.config.DNSNames { - if port == "443" { - audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name)) - } - audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port)) - - } - a.audiences = audiences - // Decrypt and load intermediate public / private key pair. if len(a.config.Password) > 0 { a.intermediateIdentity, err = x509util.LoadIdentityFromDisk( diff --git a/authority/provisioners.go b/authority/provisioners.go index 89f6b253..098cba4c 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1,11 +1,25 @@ package authority import ( + "crypto/sha1" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "math" "net/http" + "sort" + "strings" "github.com/pkg/errors" ) +// DefaultProvisionersLimit is the default limit for listing provisioners. +const DefaultProvisionersLimit = 20 + +// DefaultProvisionersMax is the maximum limit for listing provisioners. +const DefaultProvisionersMax = 100 + // GetEncryptedKey returns the JWE key corresponding to the given kid argument. func (a *Authority) GetEncryptedKey(kid string) (string, error) { val, ok := a.encryptedKeyIndex.Load(kid) @@ -27,3 +41,74 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) { func (a *Authority) GetProvisioners() ([]*Provisioner, error) { return a.config.AuthorityConfig.Provisioners, nil } + +type uidProvisioner struct { + provisioner *provisioner.Provisioner + uid string +} + +func newSortedProvisioners(provisioners []*provisioner.Provisioner) (provisionerSlice, error) { + if len(provisioners) > math.MaxUint32 { + return nil, errors.New("too many provisioners") + } + + var slice provisionerSlice + bi := make([]byte, 4) + for i, p := range provisioners { + sum, err := provisionerSum(p) + if err != nil { + return nil, err + } + // Use the first 4 bytes (32bit) of the sum to insert the order + // Using big endian format to get the strings sorted: + // 0x00000000, 0x00000001, 0x00000002, ... + binary.BigEndian.PutUint32(bi, uint32(i)) + sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] + bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0 + slice = append(slice, uidProvisioner{ + provisioner: p, + uid: hex.EncodeToString(sum), + }) + } + sort.Sort(slice) + return slice, nil +} + +type provisionerSlice []uidProvisioner + +func (p provisionerSlice) Len() int { return len(p) } +func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } +func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +func (p provisionerSlice) Find(cursor string, limit int) ([]*provisioner.Provisioner, string) { + switch { + case limit <= 0: + limit = DefaultProvisionersLimit + case limit > DefaultProvisionersMax: + limit = DefaultProvisionersMax + } + + n := len(p) + cursor = fmt.Sprintf("%040s", cursor) + i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor }) + + var slice []*provisioner.Provisioner + for ; i < n && len(slice) < limit; i++ { + slice = append(slice, p[i].provisioner) + } + if i < n { + return slice, strings.TrimLeft(p[i].uid, "0") + } + return slice, "" +} + +// provisionerSum returns the SHA1 of the json representation of the +// provisioner. From this we will create the unique and sorted id. +func provisionerSum(p *provisioner.Provisioner) ([]byte, error) { + b, err := json.Marshal(p.Key) + if err != nil { + return nil, errors.Wrap(err, "error marshalling provisioner") + } + sum := sha1.Sum(b) + return sum[:], nil +} diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 973a59f6..688e48e6 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,11 +1,17 @@ package authority import ( + "encoding/json" "net/http" + "reflect" + "strings" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/ca-component/provisioner" + "github.com/smallstep/cli/crypto/randutil" + "github.com/smallstep/cli/jose" ) func TestGetEncryptedKey(t *testing.T) { @@ -120,3 +126,102 @@ func TestGetProvisioners(t *testing.T) { }) } } + +func generateProvisioner(t *testing.T) *provisioner.Provisioner { + issuer, err := randutil.Alphanumeric(10) + assert.FatalError(t, err) + // Create a new JWK + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + // Encrypt JWK + salt, err := randutil.Salt(jose.PBKDF2SaltSize) + assert.FatalError(t, err) + b, err := json.Marshal(jwk) + assert.FatalError(t, err) + recipient := jose.Recipient{ + Algorithm: jose.PBES2_HS256_A128KW, + Key: []byte("password"), + PBES2Count: jose.PBKDF2Iterations, + PBES2Salt: salt, + } + opts := new(jose.EncrypterOptions) + opts.WithContentType(jose.ContentType("jwk+json")) + encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) + assert.FatalError(t, err) + jwe, err := encrypter.Encrypt(b) + assert.FatalError(t, err) + // get public and encrypted keys + public := jwk.Public() + encrypted, err := jwe.CompactSerialize() + assert.FatalError(t, err) + return &provisioner.Provisioner{ + Issuer: issuer, + Type: "JWT", + Key: &public, + EncryptedKey: encrypted, + } +} + +func Test_newSortedProvisioners(t *testing.T) { + provisioners := make([]*provisioner.Provisioner, 20) + for i := range provisioners { + provisioners[i] = generateProvisioner(t) + } + + ps, err := newSortedProvisioners(provisioners) + assert.FatalError(t, err) + prev := "" + for i, p := range ps { + if p.uid < prev { + t.Errorf("%s should be less that %s", p.uid, prev) + } + if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID { + t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID) + } + prev = p.uid + } +} + +func Test_provisionerSlice_Find(t *testing.T) { + trim := func(s string) string { + return strings.TrimLeft(s, "0") + } + provisioners := make([]*provisioner.Provisioner, 20) + for i := range provisioners { + provisioners[i] = generateProvisioner(t) + } + ps, err := newSortedProvisioners(provisioners) + assert.FatalError(t, err) + + type args struct { + cursor string + limit int + } + tests := []struct { + name string + p provisionerSlice + args args + want []*provisioner.Provisioner + want1 string + }{ + {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""}, + {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""}, + {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)}, + {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""}, + {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)}, + {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)}, + {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""}, + {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1) + } + }) + } +}