commit
095ab891e7
@ -1,117 +0,0 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
x509 "github.com/smallstep/cli/pkg/x509"
|
||||
)
|
||||
|
||||
// certClaim interface is implemented by types used to validate specific claims in a
|
||||
// certificate request.
|
||||
type certClaim interface {
|
||||
Valid(crt *x509.Certificate) error
|
||||
}
|
||||
|
||||
// ValidateClaims returns nil if all the claims are validated, it will return
|
||||
// the first error if a claim fails.
|
||||
func validateClaims(crt *x509.Certificate, claims []certClaim) (err error) {
|
||||
for _, c := range claims {
|
||||
if err = c.Valid(crt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// commonNameClaim validates the common name of a certificate request.
|
||||
type commonNameClaim struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// Valid checks that certificate request common name matches the one configured.
|
||||
func (c *commonNameClaim) Valid(crt *x509.Certificate) error {
|
||||
if crt.Subject.CommonName == "" {
|
||||
return errors.New("common name cannot be empty")
|
||||
}
|
||||
if crt.Subject.CommonName != c.name {
|
||||
return errors.Errorf("common name claim failed - got %s, want %s", crt.Subject.CommonName, c.name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dnsNamesClaim struct {
|
||||
names []string
|
||||
}
|
||||
|
||||
// Valid checks that certificate request DNS Names match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (c *dnsNamesClaim) Valid(crt *x509.Certificate) error {
|
||||
tokMap := make(map[string]int)
|
||||
for _, e := range c.names {
|
||||
tokMap[e] = 1
|
||||
}
|
||||
crtMap := make(map[string]int)
|
||||
for _, e := range crt.DNSNames {
|
||||
crtMap[e] = 1
|
||||
}
|
||||
if !reflect.DeepEqual(tokMap, crtMap) {
|
||||
return errors.Errorf("DNS names claim failed - got %s, want %s", crt.DNSNames, c.names)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipAddressesClaim struct {
|
||||
ips []net.IP
|
||||
}
|
||||
|
||||
// Valid checks that certificate request IP Addresses match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (c *ipAddressesClaim) Valid(crt *x509.Certificate) error {
|
||||
tokMap := make(map[string]int)
|
||||
for _, e := range c.ips {
|
||||
tokMap[e.String()] = 1
|
||||
}
|
||||
crtMap := make(map[string]int)
|
||||
for _, e := range crt.IPAddresses {
|
||||
crtMap[e.String()] = 1
|
||||
}
|
||||
if !reflect.DeepEqual(tokMap, crtMap) {
|
||||
return errors.Errorf("IP Addresses claim failed - got %v, want %v", crt.IPAddresses, c.ips)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// certTemporalClaim validates the certificate temporal validity settings.
|
||||
type certTemporalClaim struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}
|
||||
|
||||
// Validate validates the certificate temporal validity settings.
|
||||
func (ctc *certTemporalClaim) Valid(crt *x509.Certificate) error {
|
||||
var (
|
||||
na = crt.NotAfter
|
||||
nb = crt.NotBefore
|
||||
d = na.Sub(nb)
|
||||
now = time.Now()
|
||||
)
|
||||
|
||||
if na.Before(now) {
|
||||
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
||||
}
|
||||
if na.Before(nb) {
|
||||
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
||||
}
|
||||
if d < ctc.min {
|
||||
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
||||
d, ctc.min)
|
||||
}
|
||||
if d > ctc.max {
|
||||
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
||||
d, ctc.max)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,132 +0,0 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
x509 "github.com/smallstep/cli/pkg/x509"
|
||||
)
|
||||
|
||||
func TestCommonNameClaim_Valid(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
cnc certClaim
|
||||
crt *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
"empty-common-name": {
|
||||
cnc: &commonNameClaim{name: "foo"},
|
||||
crt: &x509.Certificate{},
|
||||
err: errors.New("common name cannot be empty"),
|
||||
},
|
||||
"wrong-common-name": {
|
||||
cnc: &commonNameClaim{name: "foo"},
|
||||
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}},
|
||||
err: errors.New("common name claim failed - got bar, want foo"),
|
||||
},
|
||||
"ok": {
|
||||
cnc: &commonNameClaim{name: "foo"},
|
||||
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.cnc.Valid(tc.crt)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAddressesClaim_Valid(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
iac certClaim
|
||||
crt *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
"unexpected-ip-in-crt": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
|
||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1 1.1.1.1], want [127.0.0.1]"),
|
||||
},
|
||||
"missing-ip-in-crt": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want [127.0.0.1 1.1.1.1]"),
|
||||
},
|
||||
"invalid-matcher-nonempty-ips": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want []"),
|
||||
},
|
||||
"ok": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
},
|
||||
"ok-multiple-identical-ip-entries": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1")}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.iac.Valid(tc.crt)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSNamesClaim_Valid(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
dnc certClaim
|
||||
crt *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
"unexpected-dns-name-in-crt": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"foo", "bar"}},
|
||||
err: errors.New("DNS names claim failed - got [foo bar], want [foo]"),
|
||||
},
|
||||
"ok": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"bar", "foo"}},
|
||||
},
|
||||
"missing-dns-name-in-crt": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"foo"}},
|
||||
err: errors.New("DNS names claim failed - got [foo], want [foo bar]"),
|
||||
},
|
||||
"ok-multiple-identical-dns-entries": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"foo", "foo", "foo"}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.dnc.Valid(tc.crt)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,212 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
const DefaultProvisionersLimit = 20
|
||||
|
||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||
const DefaultProvisionersMax = 100
|
||||
|
||||
type uidProvisioner struct {
|
||||
provisioner Interface
|
||||
uid string
|
||||
}
|
||||
|
||||
type provisionerSlice []uidProvisioner
|
||||
|
||||
func (p provisionerSlice) Len() int { return len(p) }
|
||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// Collection is a memory map of provisioners.
|
||||
type Collection struct {
|
||||
byID *sync.Map
|
||||
byKey *sync.Map
|
||||
sorted provisionerSlice
|
||||
audiences []string
|
||||
}
|
||||
|
||||
// NewCollection initializes a collection of provisioners. The given list of
|
||||
// audiences are the audiences used by the JWT provisioner.
|
||||
func NewCollection(audiences []string) *Collection {
|
||||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
audiences: audiences,
|
||||
}
|
||||
}
|
||||
|
||||
// Load a provisioner by the ID.
|
||||
func (c *Collection) Load(id string) (Interface, bool) {
|
||||
return loadProvisioner(c.byID, id)
|
||||
}
|
||||
|
||||
// LoadByToken parses the token claims and loads the provisioner associated.
|
||||
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
|
||||
// match with server audiences
|
||||
if matchesAudience(claims.Audience, c.audiences) {
|
||||
// If matches with stored audiences it will be a JWT token (default), and
|
||||
// the id would be <issuer>:<kid>.
|
||||
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
// The ID will be just the clientID stored in azp or aud.
|
||||
var payload openIDPayload
|
||||
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
// audience is required
|
||||
if len(payload.Audience) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
if len(payload.AuthorizedParty) > 0 {
|
||||
return c.Load(payload.AuthorizedParty)
|
||||
}
|
||||
return c.Load(payload.Audience[0])
|
||||
}
|
||||
|
||||
// LoadByCertificate looks for the provisioner extension and extracts the
|
||||
// proper id to load the provisioner.
|
||||
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
|
||||
for _, e := range cert.Extensions {
|
||||
if e.Id.Equal(stepOIDProvisioner) {
|
||||
var provisioner stepProvisionerASN1
|
||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if provisioner.Type == int(TypeJWK) {
|
||||
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
|
||||
}
|
||||
return c.Load(string(provisioner.CredentialID))
|
||||
}
|
||||
}
|
||||
|
||||
// Default to noop provisioner if an extension is not found. This allows to
|
||||
// accept a renewal of a cert without the provisioner extension.
|
||||
return &noop{}, true
|
||||
}
|
||||
|
||||
// LoadEncryptedKey returns an encrypted key by indexed by KeyID. At this moment
|
||||
// only JWK encrypted keys are indexed by KeyID.
|
||||
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
||||
p, ok := loadProvisioner(c.byKey, keyID)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
_, key, ok := p.GetEncryptedKey()
|
||||
return key, ok
|
||||
}
|
||||
|
||||
// Store adds a provisioner to the collection and enforces the uniqueness of
|
||||
// provisioner IDs.
|
||||
func (c *Collection) Store(p Interface) error {
|
||||
// Store provisioner always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
|
||||
return errors.New("cannot add multiple provisioners with the same id")
|
||||
}
|
||||
|
||||
// Store provisioner in byKey if EncryptedKey is defined.
|
||||
if kid, _, ok := p.GetEncryptedKey(); ok {
|
||||
c.byKey.Store(kid, p)
|
||||
}
|
||||
|
||||
// Store sorted provisioners.
|
||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
bi := make([]byte, 4)
|
||||
sum := provisionerSum(p)
|
||||
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
c.sorted = append(c.sorted, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
sort.Sort(c.sorted)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find implements pagination on a list of sorted provisioners.
|
||||
func (c *Collection) Find(cursor string, limit int) (List, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultProvisionersLimit
|
||||
case limit > DefaultProvisionersMax:
|
||||
limit = DefaultProvisionersMax
|
||||
}
|
||||
|
||||
n := c.sorted.Len()
|
||||
cursor = fmt.Sprintf("%040s", cursor)
|
||||
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
|
||||
|
||||
slice := List{}
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, c.sorted[i].provisioner)
|
||||
}
|
||||
|
||||
if i < n {
|
||||
return slice, strings.TrimLeft(c.sorted[i].uid, "0")
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
func loadProvisioner(m *sync.Map, key string) (Interface, bool) {
|
||||
i, ok := m.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
p, ok := i.(Interface)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return p, true
|
||||
}
|
||||
|
||||
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
|
||||
// create the unique and sorted id.
|
||||
func provisionerSum(p Interface) []byte {
|
||||
sum := sha1.Sum([]byte(p.GetID()))
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
// matchesAudience returns true if A and B share at least one element.
|
||||
func matchesAudience(as, bs []string) bool {
|
||||
if len(bs) == 0 || len(as) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, b := range bs {
|
||||
for _, a := range as {
|
||||
if b == a || stripPort(a) == stripPort(b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||
// produces errors it will just return the passed argument.
|
||||
func stripPort(rawurl string) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return rawurl
|
||||
}
|
||||
u.Host = u.Hostname()
|
||||
return u.String()
|
||||
}
|
@ -0,0 +1,390 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func TestCollection_Load(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p.GetID(), p)
|
||||
byID.Store("string", "a-string")
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
}
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", fields{byID}, args{p.GetID()}, p, true},
|
||||
{"fail", fields{byID}, args{"fail"}, nil, false},
|
||||
{"invalid", fields{byID}, args{"string"}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
}
|
||||
got, got1 := c.Load(tt.args.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByToken(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
byID.Store(p3.GetID(), p3)
|
||||
byID.Store("string", "a-string")
|
||||
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk)
|
||||
assert.FatalError(t, err)
|
||||
t1, c1, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk)
|
||||
assert.FatalError(t, err)
|
||||
t2, c2, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t3, c3, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t4, c4, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
}
|
||||
type args struct {
|
||||
token *jose.JSONWebToken
|
||||
claims *jose.Claims
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
|
||||
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
|
||||
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
|
||||
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByCertificate(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
|
||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||
assert.FatalError(t, err)
|
||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||
assert.FatalError(t, err)
|
||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ok1Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok1Ext},
|
||||
}
|
||||
ok2Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok2Ext},
|
||||
}
|
||||
notFoundCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{notFoundExt},
|
||||
}
|
||||
badCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{
|
||||
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||
},
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
}
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByCertificate(tt.args.cert)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadEncryptedKey(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p1))
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p2))
|
||||
|
||||
// Add oidc in byKey.
|
||||
// It should not happen.
|
||||
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
|
||||
c.byKey.Store(p2KeyID, p2)
|
||||
|
||||
type args struct {
|
||||
keyID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
|
||||
{"oidc", args{p2KeyID}, "", false},
|
||||
{"notFound", args{"not-found"}, "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
|
||||
if got != tt.want {
|
||||
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_Store(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
p Interface
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", args{p1}, false},
|
||||
{"ok2", args{p2}, false},
|
||||
{"fail1", args{p1}, true},
|
||||
{"fail2", args{p2}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_Find(t *testing.T) {
|
||||
c, err := generateCollection(10, 10)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
trim := func(s string) string {
|
||||
return strings.TrimLeft(s, "0")
|
||||
}
|
||||
toList := func(ps provisionerSlice) List {
|
||||
l := List{}
|
||||
for _, p := range ps {
|
||||
l = append(l, p.provisioner)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cursor string
|
||||
limit int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want List
|
||||
want1 string
|
||||
}{
|
||||
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
|
||||
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
|
||||
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
|
||||
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
|
||||
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
|
||||
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
|
||||
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
|
||||
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_matchesAudience(t *testing.T) {
|
||||
type matchesTest struct {
|
||||
a, b []string
|
||||
exp bool
|
||||
}
|
||||
tests := map[string]matchesTest{
|
||||
"false arg1 empty": {
|
||||
a: []string{},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: false,
|
||||
},
|
||||
"false arg2 empty": {
|
||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{},
|
||||
exp: false,
|
||||
},
|
||||
"false arg1,arg2 empty": {
|
||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"step-gateway", "step-cli"},
|
||||
exp: false,
|
||||
},
|
||||
"false": {
|
||||
a: []string{"step-gateway", "step-cli"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: false,
|
||||
},
|
||||
"true": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsA": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsB": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsAB": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
|
||||
exp: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_stripPort(t *testing.T) {
|
||||
type args struct {
|
||||
rawurl string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
|
||||
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
|
||||
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := stripPort(tt.args.rawurl); got != tt.want {
|
||||
t.Errorf("stripPort() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.Duration.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var (
|
||||
s string
|
||||
_d time.Duration
|
||||
)
|
||||
if d == nil {
|
||||
return errors.New("duration cannot be nil")
|
||||
}
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshaling %s", data)
|
||||
}
|
||||
if _d, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
d.Duration = _d
|
||||
return
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDuration_UnmarshalJSON(t *testing.T) {
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
d *Duration
|
||||
args args
|
||||
want *Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
|
||||
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
|
||||
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
|
||||
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
|
||||
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
|
||||
{"nil", nil, args{nil}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(tt.d, tt.want) {
|
||||
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
d *Duration
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.d.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,125 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// jwtPayload extends jwt.Claims with step attributes.
|
||||
type jwtPayload struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans,omitempty"`
|
||||
}
|
||||
|
||||
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
||||
// signature requests.
|
||||
type JWK struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
EncryptedKey string `json:"encryptedKey,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
audiences []string
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
// should uniquely identify any JWK provisioner.
|
||||
func (p *JWK) GetID() string {
|
||||
return p.Name + ":" + p.Key.KeyID
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
func (p *JWK) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
||||
// GetType returns the type of provisioner.
|
||||
func (p *JWK) GetType() Type {
|
||||
return TypeJWK
|
||||
}
|
||||
|
||||
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
|
||||
func (p *JWK) GetEncryptedKey() (string, string, bool) {
|
||||
return p.Key.KeyID, p.EncryptedKey, len(p.EncryptedKey) > 0
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of a JWK type.
|
||||
func (p *JWK) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
case p.Name == "":
|
||||
return errors.New("provisioner name cannot be empty")
|
||||
case p.Key == nil:
|
||||
return errors.New("provisioner key cannot be empty")
|
||||
}
|
||||
p.Claims, err = p.Claims.Init(&config.Claims)
|
||||
p.audiences = config.Audiences
|
||||
return err
|
||||
}
|
||||
|
||||
// Authorize validates the given token.
|
||||
func (p *JWK) Authorize(token string) ([]SignOption, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
|
||||
var claims jwtPayload
|
||||
if err = jwt.Claims(p.Key, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
// more than a few minutes.
|
||||
if err = claims.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: p.Name,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, p.audiences) {
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("token subject cannot be empty")
|
||||
}
|
||||
|
||||
// NOTE: This is for backwards compatibility with older versions of cli
|
||||
// and certificates. Older versions added the token subject as the only SAN
|
||||
// in a CSR by default.
|
||||
if len(claims.SANs) == 0 {
|
||||
claims.SANs = []string{claims.Subject}
|
||||
}
|
||||
|
||||
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
||||
return []SignOption{
|
||||
commonNameValidator(claims.Subject),
|
||||
dnsNamesValidator(dnsNames),
|
||||
ipAddressesValidator(ips),
|
||||
profileDefaultDuration(p.Claims.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
||||
newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||
func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
if p.Claims.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *JWK) AuthorizeRevoke(token string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
@ -0,0 +1,256 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = Claims{
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func TestJWK_Getters(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
if got := p.GetID(); got != p.Name+":"+p.Key.KeyID {
|
||||
t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("JWK.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeJWK {
|
||||
t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false {
|
||||
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, p.Key.KeyID, p.EncryptedKey, true)
|
||||
}
|
||||
p.EncryptedKey = ""
|
||||
kid, key, ok = p.GetEncryptedKey()
|
||||
if kid != p.Key.KeyID || key != "" || ok == true {
|
||||
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, p.Key.KeyID, "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_Init(t *testing.T) {
|
||||
type ProvisionerValidateTest struct {
|
||||
p *JWK
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{
|
||||
Type: "JWK",
|
||||
},
|
||||
err: errors.New("provisioner name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo"},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar"},
|
||||
err: errors.New("provisioner key cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}
|
||||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.p.Init(config)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_Authorize(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
key1, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
key2, err := decryptJSONWebKey(p2.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
|
||||
assert.FatalError(t, err)
|
||||
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2)
|
||||
assert.FatalError(t, err)
|
||||
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Invalid tokens
|
||||
parts := strings.Split(t1, ".")
|
||||
key3, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
// missing key
|
||||
failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3)
|
||||
assert.FatalError(t, err)
|
||||
// invalid token
|
||||
failTok := "foo." + parts[1] + "." + parts[2]
|
||||
// invalid claims
|
||||
failClaims := parts[0] + ".foo." + parts[1]
|
||||
// invalid issuer
|
||||
failIss, err := generateSimpleToken("foobar", testAudiences[0], key1)
|
||||
assert.FatalError(t, err)
|
||||
// invalid audience
|
||||
failAud, err := generateSimpleToken(p1.Name, "foobar", key1)
|
||||
assert.FatalError(t, err)
|
||||
// invalid signature
|
||||
failSig := t1[0 : len(t1)-2]
|
||||
// no subject
|
||||
failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1)
|
||||
assert.FatalError(t, err)
|
||||
// expired
|
||||
failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
|
||||
assert.FatalError(t, err)
|
||||
// not before
|
||||
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Remove encrypted key for p2
|
||||
p2.EncryptedKey = ""
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, false},
|
||||
{"ok-no-encrypted-key", p2, args{t2}, false},
|
||||
{"ok-no-sans", p1, args{t3}, false},
|
||||
{"fail-key", p1, args{failKey}, true},
|
||||
{"fail-token", p1, args{failTok}, true},
|
||||
{"fail-claims", p1, args{failClaims}, true},
|
||||
{"fail-issuer", p1, args{failIss}, true},
|
||||
{"fail-audience", p1, args{failAud}, true},
|
||||
{"fail-signature", p1, args{failSig}, true},
|
||||
{"fail-subject", p1, args{failSub}, true},
|
||||
{"fail-expired", p1, args{failExp}, true},
|
||||
{"fail-not-before", p1, args{failNbf}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.prov.Authorize(tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NotNil(t, got)
|
||||
assert.Len(t, 6, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{
|
||||
globalClaims: &globalProvisionerClaims,
|
||||
DisableRenewal: &disable,
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeRevoke(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
key1, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"disabled", p1, args{t1}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,135 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCacheAge = 12 * time.Hour
|
||||
defaultCacheJitter = 1 * time.Hour
|
||||
)
|
||||
|
||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||
|
||||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
uri string
|
||||
keySet jose.JSONWebKeySet
|
||||
timer *time.Timer
|
||||
expiry time.Time
|
||||
jitter time.Duration
|
||||
}
|
||||
|
||||
func newKeyStore(uri string) (*keyStore, error) {
|
||||
keys, age, err := getKeysFromJWKsURI(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ks := &keyStore{
|
||||
uri: uri,
|
||||
keySet: keys,
|
||||
expiry: getExpirationTime(age),
|
||||
jitter: getCacheJitter(age),
|
||||
}
|
||||
next := ks.nextReloadDuration(age)
|
||||
ks.timer = time.AfterFunc(next, ks.reload)
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (ks *keyStore) Close() {
|
||||
ks.timer.Stop()
|
||||
}
|
||||
|
||||
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||
ks.RLock()
|
||||
// Force reload if expiration has passed
|
||||
if time.Now().After(ks.expiry) {
|
||||
ks.RUnlock()
|
||||
ks.reload()
|
||||
ks.RLock()
|
||||
}
|
||||
keys = ks.keySet.Key(kid)
|
||||
ks.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (ks *keyStore) reload() {
|
||||
var next time.Duration
|
||||
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||
if err != nil {
|
||||
next = ks.nextReloadDuration(ks.jitter / 2)
|
||||
} else {
|
||||
ks.Lock()
|
||||
ks.keySet = keys
|
||||
ks.expiry = getExpirationTime(age)
|
||||
ks.jitter = getCacheJitter(age)
|
||||
next = ks.nextReloadDuration(age)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
ks.Lock()
|
||||
ks.timer.Reset(next)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||
n := rand.Int63n(int64(ks.jitter))
|
||||
age -= time.Duration(n)
|
||||
if age < 0 {
|
||||
age = 0
|
||||
}
|
||||
return age
|
||||
}
|
||||
|
||||
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||
var keys jose.JSONWebKeySet
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
|
||||
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
|
||||
}
|
||||
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
||||
}
|
||||
|
||||
func getCacheAge(cacheControl string) time.Duration {
|
||||
age := defaultCacheAge
|
||||
if len(cacheControl) > 0 {
|
||||
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
|
||||
if len(match) > 0 {
|
||||
if len(match[0]) == 2 {
|
||||
maxAge := match[0][1]
|
||||
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
|
||||
if err != nil {
|
||||
return defaultCacheAge
|
||||
}
|
||||
age = time.Duration(maxAgeInt) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
return age
|
||||
}
|
||||
|
||||
func getCacheJitter(age time.Duration) time.Duration {
|
||||
switch {
|
||||
case age > time.Hour:
|
||||
return defaultCacheJitter
|
||||
default:
|
||||
return age / 3
|
||||
}
|
||||
}
|
||||
|
||||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Truncate(time.Second).Add(age)
|
||||
}
|
@ -0,0 +1,121 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func Test_newKeyStore(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
uri string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want jose.JSONWebKeySet
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{srv.URL}, ks.keySet, false},
|
||||
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newKeyStore(tt.args.uri)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
if !reflect.DeepEqual(got.keySet, tt.want) {
|
||||
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keyStore(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
ks, err := newKeyStore(srv.URL + "/random")
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
ks.RLock()
|
||||
keySet1 := ks.keySet
|
||||
ks.RUnlock()
|
||||
// Check contents
|
||||
assert.Len(t, 2, keySet1.Keys)
|
||||
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
|
||||
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Wait for rotation
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
ks.RLock()
|
||||
keySet2 := ks.keySet
|
||||
ks.RUnlock()
|
||||
if reflect.DeepEqual(keySet1, keySet2) {
|
||||
t.Error("keyStore did not rotated")
|
||||
}
|
||||
|
||||
// Check contents
|
||||
assert.Len(t, 2, keySet2.Keys)
|
||||
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
|
||||
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Check hits
|
||||
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||
assert.FatalError(t, err)
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
defer resp.Body.Close()
|
||||
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||
assert.FatalError(t, err)
|
||||
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||
}
|
||||
|
||||
func Test_keyStore_Get(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
kid string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
ks *keyStore
|
||||
args args
|
||||
wantKeys []jose.JSONWebKey
|
||||
}{
|
||||
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
|
||||
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
|
||||
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
|
||||
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
package provisioner
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
// noop provisioners is a provisioner that accepts anything.
|
||||
type noop struct{}
|
||||
|
||||
func (p *noop) GetID() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (p *noop) GetName() string {
|
||||
return "noop"
|
||||
}
|
||||
func (p *noop) GetType() Type {
|
||||
return noopType
|
||||
}
|
||||
|
||||
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func (p *noop) Init(config Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *noop) Authorize(token string) ([]SignOption, error) {
|
||||
return []SignOption{}, nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeRevoke(token string) error {
|
||||
return nil
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func Test_noop(t *testing.T) {
|
||||
p := noop{}
|
||||
assert.Equals(t, "noop", p.GetID())
|
||||
assert.Equals(t, "noop", p.GetName())
|
||||
assert.Equals(t, noopType, p.GetType())
|
||||
assert.Equals(t, nil, p.Init(Config{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRenewal(&x509.Certificate{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRevoke("foo"))
|
||||
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
assert.Equals(t, "", kid)
|
||||
assert.Equals(t, "", key)
|
||||
assert.Equals(t, false, ok)
|
||||
|
||||
sigOptions, err := p.Authorize("foo")
|
||||
assert.Equals(t, []SignOption{}, sigOptions)
|
||||
assert.Equals(t, nil, err)
|
||||
}
|
@ -0,0 +1,243 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// openIDConfiguration contains the necessary properties in the
|
||||
// `/.well-known/openid-configuration` document.
|
||||
type openIDConfiguration struct {
|
||||
Issuer string `json:"issuer"`
|
||||
JWKSetURI string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
// Validate validates the values in a well-known OpenID configuration endpoint.
|
||||
func (c openIDConfiguration) Validate() error {
|
||||
switch {
|
||||
case c.Issuer == "":
|
||||
return errors.New("issuer cannot be empty")
|
||||
case c.JWKSetURI == "":
|
||||
return errors.New("jwks_uri cannot be empty")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// openIDPayload represents the fields on the id_token JWT payload.
|
||||
type openIDPayload struct {
|
||||
jose.Claims
|
||||
AtHash string `json:"at_hash"`
|
||||
AuthorizedParty string `json:"azp"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Hd string `json:"hd"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||
//
|
||||
// ClientSecret is mandatory, but it can be an empty string.
|
||||
type OIDC struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
ConfigurationEndpoint string `json:"configurationEndpoint"`
|
||||
Admins []string `json:"admins,omitempty"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
configuration openIDConfiguration
|
||||
keyStore *keyStore
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the given email is in the Admins whitelist, false
|
||||
// otherwise.
|
||||
func (o *OIDC) IsAdmin(email string) bool {
|
||||
email = sanitizeEmail(email)
|
||||
for _, e := range o.Admins {
|
||||
if email == sanitizeEmail(e) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeEmail(email string) string {
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
email = email[:i] + strings.ToLower(email[i:])
|
||||
}
|
||||
return email
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier, the OIDC provisioner the
|
||||
// uses the clientID for this.
|
||||
func (o *OIDC) GetID() string {
|
||||
return o.ClientID
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
func (o *OIDC) GetName() string {
|
||||
return o.Name
|
||||
}
|
||||
|
||||
// GetType returns the type of provisioner.
|
||||
func (o *OIDC) GetType() Type {
|
||||
return TypeOIDC
|
||||
}
|
||||
|
||||
// GetEncryptedKey is not available in an OIDC provisioner.
|
||||
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Init validates and initializes the OIDC provider.
|
||||
func (o *OIDC) Init(config Config) (err error) {
|
||||
switch {
|
||||
case o.Type == "":
|
||||
return errors.New("type cannot be empty")
|
||||
case o.Name == "":
|
||||
return errors.New("name cannot be empty")
|
||||
case o.ClientID == "":
|
||||
return errors.New("clientID cannot be empty")
|
||||
case o.ConfigurationEndpoint == "":
|
||||
return errors.New("configurationEndpoint cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
// Decode and validate openid-configuration endpoint
|
||||
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.configuration.Validate(); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
|
||||
}
|
||||
// Get JWK key set
|
||||
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePayload validates the given token payload.
|
||||
func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
|
||||
// than a few minutes.
|
||||
if err := p.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: o.configuration.Issuer,
|
||||
Audience: jose.Audience{o.ClientID},
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return errors.Wrap(err, "failed to validate payload")
|
||||
}
|
||||
|
||||
// Validate azp if present
|
||||
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||
return errors.New("failed to validate payload: invalid azp")
|
||||
}
|
||||
|
||||
// Enforce an email claim
|
||||
if p.Email == "" {
|
||||
return errors.New("failed to validate payload: email not found")
|
||||
}
|
||||
|
||||
// Validate domains (case-insensitive)
|
||||
if !o.IsAdmin(p.Email) && len(o.Domains) > 0 {
|
||||
email := sanitizeEmail(p.Email)
|
||||
var found bool
|
||||
for _, d := range o.Domains {
|
||||
if strings.HasSuffix(email, "@"+strings.ToLower(d)) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("failed to validate payload: email is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authorize validates the given token.
|
||||
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
|
||||
// Parse claims to get the kid
|
||||
var claims openIDPayload
|
||||
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
}
|
||||
|
||||
found := false
|
||||
kid := jwt.Headers[0].KeyID
|
||||
keys := o.keyStore.Get(kid)
|
||||
for _, key := range keys {
|
||||
if err := jwt.Claims(key, &claims); err == nil {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("cannot validate token")
|
||||
}
|
||||
|
||||
if err := o.ValidatePayload(claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Admins should be able to authorize any SAN
|
||||
if o.IsAdmin(claims.Email) {
|
||||
return []SignOption{
|
||||
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return []SignOption{
|
||||
emailOnlyIdentity(claims.Email),
|
||||
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||
func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
if o.Claims.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (o *OIDC) AuthorizeRevoke(token string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func getAndDecode(uri string, v interface{}) error {
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
|
||||
return errors.Wrapf(err, "error reading %s", uri)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,327 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func Test_openIDConfiguration_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Issuer string
|
||||
JWKSetURI string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"the-issuer", "the-jwks-uri"}, false},
|
||||
{"no-issuer", fields{"", "the-jwks-uri"}, true},
|
||||
{"no-jwks-uri", fields{"the-issuer", ""}, true},
|
||||
{"empty", fields{"", ""}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := openIDConfiguration{
|
||||
Issuer: tt.fields.Issuer,
|
||||
JWKSetURI: tt.fields.JWKSetURI,
|
||||
}
|
||||
if err := c.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_Getters(t *testing.T) {
|
||||
p, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
if got := p.GetID(); got != p.ClientID {
|
||||
t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeOIDC {
|
||||
t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_Init(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Name string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ConfigurationEndpoint string
|
||||
Claims *Claims
|
||||
Admins []string
|
||||
Domains []string
|
||||
}
|
||||
type args struct {
|
||||
config Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false},
|
||||
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false},
|
||||
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
|
||||
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &OIDC{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
ClientID: tt.fields.ClientID,
|
||||
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
|
||||
Claims: tt.fields.Claims,
|
||||
Admins: tt.fields.Admins,
|
||||
}
|
||||
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr == false {
|
||||
assert.Len(t, 2, p.keyStore.keySet.Keys)
|
||||
assert.Equals(t, openIDConfiguration{
|
||||
Issuer: "the-issuer",
|
||||
JWKSetURI: srv.URL + "/jwks_uri",
|
||||
}, p.configuration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_Authorize(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
// Admin + Domains
|
||||
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||
p3.Domains = []string{"smallstep.com"}
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
assert.FatalError(t, p2.Init(config))
|
||||
assert.FatalError(t, p3.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1])
|
||||
assert.FatalError(t, err)
|
||||
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Admin email not in domains
|
||||
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Invalid email
|
||||
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
failDomain, err := generateToken("subject", "the-issuer", p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Invalid tokens
|
||||
parts := strings.Split(t1, ".")
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
// missing key
|
||||
failKey, err := generateSimpleToken("the-issuer", p1.ClientID, key)
|
||||
assert.FatalError(t, err)
|
||||
// invalid token
|
||||
failTok := "foo." + parts[1] + "." + parts[2]
|
||||
// invalid claims
|
||||
failClaims := parts[0] + ".foo." + parts[1]
|
||||
// invalid issuer
|
||||
failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// invalid audience
|
||||
failAud, err := generateSimpleToken("the-issuer", "foobar", &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// invalid signature
|
||||
failSig := t1[0 : len(t1)-2]
|
||||
// expired
|
||||
failExp, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// not before
|
||||
failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", p1, args{t1}, false},
|
||||
{"ok2", p2, args{t2}, false},
|
||||
{"admin", p3, args{t3}, false},
|
||||
{"admin", p3, args{okAdmin}, false},
|
||||
{"fail-email", p3, args{failEmail}, true},
|
||||
{"fail-domain", p3, args{failDomain}, true},
|
||||
{"fail-key", p1, args{failKey}, true},
|
||||
{"fail-token", p1, args{failTok}, true},
|
||||
{"fail-claims", p1, args{failClaims}, true},
|
||||
{"fail-issuer", p1, args{failIss}, true},
|
||||
{"fail-audience", p1, args{failAud}, true},
|
||||
{"fail-signature", p1, args{failSig}, true},
|
||||
{"fail-expired", p1, args{failExp}, true},
|
||||
{"fail-not-before", p1, args{failNbf}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.prov.Authorize(tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Println(tt)
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NotNil(t, got)
|
||||
if tt.name == "admin" {
|
||||
assert.Len(t, 3, got)
|
||||
} else {
|
||||
assert.Len(t, 4, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{
|
||||
globalClaims: &globalProvisionerClaims,
|
||||
DisableRenewal: &disable,
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"disabled", p1, args{t1}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sanitizeEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
want string
|
||||
}{
|
||||
{"equal", "name@smallstep.com", "name@smallstep.com"},
|
||||
{"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"},
|
||||
{"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"},
|
||||
{"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := sanitizeEmail(tt.email); got != tt.want {
|
||||
t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,82 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Interface is the interface that all provisioner types must implement.
|
||||
type Interface interface {
|
||||
GetID() string
|
||||
GetName() string
|
||||
GetType() Type
|
||||
GetEncryptedKey() (kid string, key string, ok bool)
|
||||
Init(config Config) error
|
||||
Authorize(token string) ([]SignOption, error)
|
||||
AuthorizeRenewal(cert *x509.Certificate) error
|
||||
AuthorizeRevoke(token string) error
|
||||
}
|
||||
|
||||
// Type indicates the provisioner Type.
|
||||
type Type int
|
||||
|
||||
const (
|
||||
noopType Type = 0
|
||||
|
||||
// TypeJWK is used to indicate the JWK provisioners.
|
||||
TypeJWK Type = 1
|
||||
|
||||
// TypeOIDC is used to indicate the OIDC provisioners.
|
||||
TypeOIDC Type = 2
|
||||
)
|
||||
|
||||
// Config defines the default parameters used in the initialization of
|
||||
// provisioners.
|
||||
type Config struct {
|
||||
// Claims are the default claims.
|
||||
Claims Claims
|
||||
// Audiences are the audiences used in the default provisioner, (JWK).
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
type provisioner struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// List represents a list of provisioners.
|
||||
type List []Interface
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler and allows to unmarshal a list of a
|
||||
// interfaces into the right type.
|
||||
func (l *List) UnmarshalJSON(data []byte) error {
|
||||
ps := []json.RawMessage{}
|
||||
if err := json.Unmarshal(data, &ps); err != nil {
|
||||
return errors.Wrap(err, "error unmarshaling provisioner list")
|
||||
}
|
||||
|
||||
*l = List{}
|
||||
for _, data := range ps {
|
||||
var typ provisioner
|
||||
if err := json.Unmarshal(data, &typ); err != nil {
|
||||
return errors.Errorf("error unmarshaling provisioner")
|
||||
}
|
||||
var p Interface
|
||||
switch strings.ToLower(typ.Type) {
|
||||
case "jwk":
|
||||
p = &JWK{}
|
||||
case "oidc":
|
||||
p = &OIDC{}
|
||||
default:
|
||||
return errors.Errorf("provisioner type %s not supported", typ.Type)
|
||||
}
|
||||
if err := json.Unmarshal(data, p); err != nil {
|
||||
return errors.Errorf("error unmarshaling provisioner")
|
||||
}
|
||||
*l = append(*l, p)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,233 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
||||
// Options contains the options that can be passed to the Sign method.
|
||||
type Options struct {
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
}
|
||||
|
||||
// SignOption is the interface used to collect all extra options used in the
|
||||
// Sign method.
|
||||
type SignOption interface{}
|
||||
|
||||
// CertificateValidator is the interface used to validate a X.509 certificate.
|
||||
type CertificateValidator interface {
|
||||
SignOption
|
||||
Valid(crt *x509.Certificate) error
|
||||
}
|
||||
|
||||
// CertificateRequestValidator is the interface used to validate a X.509
|
||||
// certificate request.
|
||||
type CertificateRequestValidator interface {
|
||||
SignOption
|
||||
Valid(req *x509.CertificateRequest) error
|
||||
}
|
||||
|
||||
// ProfileModifier is the interface used to add custom options to the profile
|
||||
// constructor. The options are used to modify the final certificate.
|
||||
type ProfileModifier interface {
|
||||
SignOption
|
||||
Option(o Options) x509util.WithOption
|
||||
}
|
||||
|
||||
// profileWithOption is a wrapper against x509util.WithOption to conform the
|
||||
// interface.
|
||||
type profileWithOption x509util.WithOption
|
||||
|
||||
func (v profileWithOption) Option(Options) x509util.WithOption {
|
||||
return x509util.WithOption(v)
|
||||
}
|
||||
|
||||
// profileDefaultDuration is a wrapper against x509util.WithOption to conform the
|
||||
// interface.
|
||||
type profileDefaultDuration time.Duration
|
||||
|
||||
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
|
||||
return x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, time.Duration(v))
|
||||
}
|
||||
|
||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||
// SAN provided is the given email address.
|
||||
type emailOnlyIdentity string
|
||||
|
||||
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||
switch {
|
||||
case len(req.DNSNames) > 0:
|
||||
return errors.New("certificate request cannot contain DNS names")
|
||||
case len(req.IPAddresses) > 0:
|
||||
return errors.New("certificate request cannot contain IP addresses")
|
||||
case len(req.URIs) > 0:
|
||||
return errors.New("certificate request cannot contain URIs")
|
||||
case len(req.EmailAddresses) == 0:
|
||||
return errors.New("certificate request does not contain any email address")
|
||||
case len(req.EmailAddresses) > 1:
|
||||
return errors.New("certificate request does not contain too many email addresses")
|
||||
case req.EmailAddresses[0] == "":
|
||||
return errors.New("certificate request cannot contain an empty email address")
|
||||
case req.EmailAddresses[0] != string(e):
|
||||
return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// commonNameValidator validates the common name of a certificate request.
|
||||
type commonNameValidator string
|
||||
|
||||
// Valid checks that certificate request common name matches the one configured.
|
||||
func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
||||
if req.Subject.CommonName == "" {
|
||||
return errors.New("certificate request cannot contain an empty common name")
|
||||
}
|
||||
if req.Subject.CommonName != string(v) {
|
||||
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
||||
type dnsNamesValidator []string
|
||||
|
||||
// Valid checks that certificate request DNS Names match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
|
||||
want := make(map[string]bool)
|
||||
for _, s := range v {
|
||||
want[s] = true
|
||||
}
|
||||
got := make(map[string]bool)
|
||||
for _, s := range req.DNSNames {
|
||||
got[s] = true
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipAddressesValidator validates the IP addresses SAN of a certificate request.
|
||||
type ipAddressesValidator []net.IP
|
||||
|
||||
// Valid checks that certificate request IP Addresses match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
|
||||
want := make(map[string]bool)
|
||||
for _, ip := range v {
|
||||
want[ip.String()] = true
|
||||
}
|
||||
got := make(map[string]bool)
|
||||
for _, ip := range req.IPAddresses {
|
||||
got[ip.String()] = true
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validityValidator validates the certificate temporal validity settings.
|
||||
type validityValidator struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}
|
||||
|
||||
// newValidityValidator return a new validity validator.
|
||||
func newValidityValidator(min, max time.Duration) *validityValidator {
|
||||
return &validityValidator{min: min, max: max}
|
||||
}
|
||||
|
||||
// Validate validates the certificate temporal validity settings.
|
||||
func (v *validityValidator) Valid(crt *x509.Certificate) error {
|
||||
var (
|
||||
na = crt.NotAfter
|
||||
nb = crt.NotBefore
|
||||
d = na.Sub(nb)
|
||||
now = time.Now()
|
||||
)
|
||||
|
||||
if na.Before(now) {
|
||||
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
||||
}
|
||||
if na.Before(nb) {
|
||||
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
||||
}
|
||||
if d < v.min {
|
||||
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
||||
d, v.min)
|
||||
}
|
||||
if d > v.max {
|
||||
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
||||
d, v.max)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||
)
|
||||
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
}
|
||||
|
||||
type provisionerExtensionOption struct {
|
||||
Type int
|
||||
Name string
|
||||
CredentialID string
|
||||
}
|
||||
|
||||
func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption {
|
||||
return &provisionerExtensionOption{
|
||||
Type: int(typ),
|
||||
Name: name,
|
||||
CredentialID: credentialID,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
crt := p.Subject()
|
||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: typ,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(credentialID),
|
||||
})
|
||||
if err != nil {
|
||||
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
||||
}
|
||||
return pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Avoid deadcode warning in profileWithOption
|
||||
_ = profileWithOption(nil)
|
||||
}
|
@ -0,0 +1,152 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_emailOnlyIdentity_Valid(t *testing.T) {
|
||||
uri, err := url.Parse("https://example.com/1.0/getUser")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e emailOnlyIdentity
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
|
||||
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
|
||||
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
|
||||
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
|
||||
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
|
||||
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
|
||||
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commonNameValidator_Valid(t *testing.T) {
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v commonNameValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||
{"empty", "", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true},
|
||||
{"wrong", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("commonNameValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dnsNamesValidator_Valid(t *testing.T) {
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v dnsNamesValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false},
|
||||
{"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false},
|
||||
{"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false},
|
||||
{"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false},
|
||||
{"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true},
|
||||
{"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true},
|
||||
{"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("dnsNamesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipAddressesValidator_Valid(t *testing.T) {
|
||||
ip1 := net.IPv4(10, 3, 2, 1)
|
||||
ip2 := net.IPv4(10, 3, 2, 2)
|
||||
ip3 := net.IPv4(10, 3, 2, 3)
|
||||
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v ipAddressesValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok0", []net.IP{}, args{&x509.CertificateRequest{IPAddresses: []net.IP{}}}, false},
|
||||
{"ok1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1}}}, false},
|
||||
{"ok2", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip2}}}, false},
|
||||
{"ok3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, false},
|
||||
{"fail1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2}}}, true},
|
||||
{"fail2", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, true},
|
||||
{"fail3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip3}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ipAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validityValidator_Valid(t *testing.T) {
|
||||
type fields struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}
|
||||
type args struct {
|
||||
crt *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validityValidator{
|
||||
min: tt.fields.min,
|
||||
max: tt.fields.max,
|
||||
}
|
||||
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr {
|
||||
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf
|
||||
MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla
|
||||
Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg
|
||||
Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN
|
||||
Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw
|
||||
QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU
|
||||
B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c
|
||||
ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET
|
||||
/A8LXNH4M06A7vE=
|
||||
-----END CERTIFICATE-----
|
@ -0,0 +1,272 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
var testAudiences = []string{
|
||||
"https://ca.smallstep.com/sign",
|
||||
"https://ca.smallsteomcom/1.0/sign",
|
||||
}
|
||||
|
||||
func must(args ...interface{}) []interface{} {
|
||||
if l := len(args); l > 0 && args[l-1] != nil {
|
||||
if err, ok := args[l-1].(error); ok {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fp, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk.KeyID = string(hex.EncodeToString(fp))
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
|
||||
var keySet jose.JSONWebKeySet
|
||||
for i := 0; i < n; i++ {
|
||||
key, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return jose.JSONWebKeySet{}, err
|
||||
}
|
||||
keySet.Keys = append(keySet.Keys, *key)
|
||||
}
|
||||
return keySet, nil
|
||||
}
|
||||
|
||||
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
||||
b, err := json.Marshal(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := new(jose.EncrypterOptions)
|
||||
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||
recipient := jose.Recipient{
|
||||
Algorithm: jose.PBES2_HS256_A128KW,
|
||||
Key: []byte("password"),
|
||||
PBES2Count: jose.PBKDF2Iterations,
|
||||
PBES2Salt: salt,
|
||||
}
|
||||
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypter.Encrypt(b)
|
||||
}
|
||||
|
||||
func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) {
|
||||
enc, err := jose.ParseEncrypted(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, err := enc.Decrypt([]byte("password"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk := new(jose.JSONWebKey)
|
||||
if err := json.Unmarshal(b, jwk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
func generateJWK() (*JWK, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwe, err := encryptJSONWebKey(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
public := jwk.Public()
|
||||
encrypted, err := jwe.CompactSerialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &JWK{
|
||||
Name: name,
|
||||
Type: "JWK",
|
||||
Key: &public,
|
||||
EncryptedKey: encrypted,
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateOIDC() (*OIDC, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
issuer, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OIDC{
|
||||
Name: name,
|
||||
Type: "OIDC",
|
||||
ClientID: clientID,
|
||||
ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration",
|
||||
Claims: &globalProvisionerClaims,
|
||||
configuration: openIDConfiguration{
|
||||
Issuer: issuer,
|
||||
JWKSetURI: "https://example.com/.well-known/jwks",
|
||||
},
|
||||
keyStore: &keyStore{
|
||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
||||
col := NewCollection(testAudiences)
|
||||
for i := 0; i < nJWK; i++ {
|
||||
p, err := generateJWK()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.Store(p)
|
||||
}
|
||||
for i := 0; i < nOIDC; i++ {
|
||||
p, err := generateOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.Store(p)
|
||||
}
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||
return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk)
|
||||
}
|
||||
|
||||
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id, err := randutil.ASCII(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := struct {
|
||||
jose.Claims
|
||||
Email string `json:"email"`
|
||||
SANS []string `json:"sans"`
|
||||
}{
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: sub,
|
||||
Issuer: iss,
|
||||
IssuedAt: jose.NewNumericDate(iat),
|
||||
NotBefore: jose.NewNumericDate(iat),
|
||||
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||
Audience: []string{aud},
|
||||
},
|
||||
Email: email,
|
||||
SANS: sans,
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
claims := new(jose.Claims)
|
||||
if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tok, claims, nil
|
||||
}
|
||||
|
||||
func generateJWKServer(n int) *httptest.Server {
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}
|
||||
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
|
||||
var ret jose.JSONWebKeySet
|
||||
for _, k := range ks.Keys {
|
||||
ret.Keys = append(ret.Keys, k.Public())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
srv := httptest.NewUnstartedServer(nil)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits.Hits++
|
||||
switch r.RequestURI {
|
||||
case "/error":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/hits":
|
||||
writeJSON(w, hits)
|
||||
case "/openid-configuration", "/.well-known/openid-configuration":
|
||||
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
||||
case "/random":
|
||||
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(keySet))
|
||||
case "/private":
|
||||
writeJSON(w, defaultKeySet)
|
||||
default:
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(defaultKeySet))
|
||||
}
|
||||
})
|
||||
|
||||
srv.Start()
|
||||
return srv
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestProvisionerInit(t *testing.T) {
|
||||
type ProvisionerValidateTest struct {
|
||||
p *Provisioner
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{},
|
||||
err: errors.New("provisioner name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{Name: "foo"},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{Name: "foo", Type: "bar"},
|
||||
err: errors.New("provisioner key cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.p.Init(&globalProvisionerClaims)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue