Merge pull request #51 from smallstep/oidc-provisioner

OIDC provisioners
pull/53/head
Mariano Cano 5 years ago committed by GitHub
commit 095ab891e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,19 +18,19 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/tlsutil"
)
// Authority is the interface implemented by a CA authority.
type Authority interface {
Authorize(ott string) ([]interface{}, error)
Authorize(ott string) ([]provisioner.SignOption, error)
GetTLSOptions() *tlsutil.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error)
GetFederation() ([]*x509.Certificate, error)
@ -161,11 +161,11 @@ type SignRequest struct {
// ProvisionersResponse is the response object that returns the list of
// provisioners.
type ProvisionersResponse struct {
Provisioners []*authority.Provisioner `json:"provisioners"`
NextCursor string `json:"nextCursor"`
Provisioners provisioner.List `json:"provisioners"`
NextCursor string `json:"nextCursor"`
}
// ProvisionerKeyResponse is the response object that returns the encryptoed key
// ProvisionerKeyResponse is the response object that returns the encrypted key
// of a provisioner.
type ProvisionerKeyResponse struct {
Key string `json:"key"`
@ -266,18 +266,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
return
}
signOpts := authority.SignOptions{
opts := provisioner.Options{
NotBefore: body.NotBefore,
NotAfter: body.NotAfter,
}
extraOpts, err := h.Authority.Authorize(body.OTT)
signOpts, err := h.Authority.Authorize(body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
return
}
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, signOpts, extraOpts...)
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, Forbidden(err))
return

@ -24,7 +24,7 @@ import (
"time"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/jose"
@ -410,22 +410,22 @@ func TestSignRequest_Validate(t *testing.T) {
type mockAuthority struct {
ret1, ret2 interface{}
err error
authorize func(ott string) ([]interface{}, error)
authorize func(ott string) ([]provisioner.SignOption, error)
getTLSOptions func() *tlsutil.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
getEncryptedKey func(kid string) (string, error)
getRoots func() ([]*x509.Certificate, error)
getFederation func() ([]*x509.Certificate, error)
}
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
if m.authorize != nil {
return m.authorize(ott)
}
return m.ret1.([]interface{}), m.err
return m.ret1.([]provisioner.SignOption), m.err
}
func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions {
@ -442,9 +442,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) {
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
if m.sign != nil {
return m.sign(cr, signOpts, extraOpts...)
return m.sign(cr, opts, signOpts...)
}
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
}
@ -456,11 +456,11 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
}
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) ([]*authority.Provisioner, string, error) {
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
if m.getProvisioners != nil {
return m.getProvisioners(nextCursor, limit)
}
return m.ret1.([]*authority.Provisioner), m.ret2.(string), m.err
return m.ret1.(provisioner.List), m.ret2.(string), m.err
}
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
@ -597,7 +597,7 @@ func Test_caHandler_Sign(t *testing.T) {
tests := []struct {
name string
input string
certAttrOpts []interface{}
certAttrOpts []provisioner.SignOption
autherr error
cert *x509.Certificate
root *x509.Certificate
@ -617,7 +617,7 @@ func Test_caHandler_Sign(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorize: func(ott string) ([]interface{}, error) {
authorize: func(ott string) ([]provisioner.SignOption, error) {
return tt.certAttrOpts, tt.autherr
},
getTLSOptions: func() *tlsutil.TLSOptions {
@ -723,14 +723,14 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err)
}
p := []*authority.Provisioner{
{
p := provisioner.List{
&provisioner.JWK{
Type: "JWK",
Name: "max",
EncryptedKey: "abc",
Key: &key,
},
{
&provisioner.JWK{
Type: "JWK",
Name: "mariano",
EncryptedKey: "def",

@ -4,10 +4,10 @@ import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"sync"
"time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util"
)
@ -16,18 +16,14 @@ const legacyAuthority = "step-certificate-authority"
// Authority implements the Certificate Authority internal interface.
type Authority struct {
config *Config
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
ottMap *sync.Map
startTime time.Time
provisionerIDIndex *sync.Map
encryptedKeyIndex *sync.Map
provisionerKeySetIndex *sync.Map
sortedProvisioners provisionerSlice
audiences []string
config *Config
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
ottMap *sync.Map
startTime time.Time
provisioners *provisioner.Collection
// Do not re-initialize
initOnce bool
}
@ -39,31 +35,11 @@ func New(config *Config) (*Authority, error) {
return nil, err
}
// Get sorted provisioners
var sorted provisionerSlice
if config.AuthorityConfig != nil {
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
if err != nil {
return nil, err
}
}
// Define audiences: legacy + possible urls without the ports.
// The CA might have proxies in front so we cannot rely on the port.
audiences := []string{legacyAuthority}
for _, name := range config.DNSNames {
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
}
var a = &Authority{
config: config,
certificates: new(sync.Map),
ottMap: new(sync.Map),
provisionerIDIndex: new(sync.Map),
encryptedKeyIndex: new(sync.Map),
provisionerKeySetIndex: new(sync.Map),
sortedProvisioners: sorted,
audiences: audiences,
config: config,
certificates: new(sync.Map),
ottMap: new(sync.Map),
provisioners: provisioner.NewCollection(config.getAudiences()),
}
if err := a.init(); err != nil {
return nil, err
@ -120,14 +96,15 @@ func (a *Authority) init() error {
}
}
// Store all the provisioners
for _, p := range a.config.AuthorityConfig.Provisioners {
a.provisionerIDIndex.Store(p.ID(), p)
if len(p.EncryptedKey) != 0 {
a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey)
if err := a.provisioners.Store(p); err != nil {
return err
}
}
a.startTime = time.Now()
// JWT numeric dates are seconds.
a.startTime = time.Now().Truncate(time.Second)
// Set flag indicating that initialization has been completed, and should
// not be repeated.
a.initOnce = true

@ -7,6 +7,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
stepJOSE "github.com/smallstep/cli/jose"
)
@ -16,22 +17,22 @@ func testAuthority(t *testing.T) *Authority {
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
disableRenewal := true
p := []*Provisioner{
{
p := provisioner.List{
&provisioner.JWK{
Name: "Max",
Type: "JWK",
Key: maxjwk,
},
{
&provisioner.JWK{
Name: "step-cli",
Type: "JWK",
Key: clijwk,
},
{
&provisioner.JWK{
Name: "dev",
Type: "JWK",
Key: maxjwk,
Claims: &ProvisionerClaims{
Claims: &provisioner.Claims{
DisableRenewal: &disableRenewal,
},
},
@ -113,24 +114,18 @@ func TestAuthorityNew(t *testing.T) {
assert.True(t, auth.initOnce)
assert.NotNil(t, auth.intermediateIdentity)
for _, p := range tc.config.AuthorityConfig.Provisioners {
_p, ok := auth.provisionerIDIndex.Load(p.ID())
_p, ok := auth.provisioners.Load(p.GetID())
assert.True(t, ok)
assert.Equals(t, p, _p)
if len(p.EncryptedKey) > 0 {
key, ok := auth.encryptedKeyIndex.Load(p.Key.KeyID)
if kid, encryptedKey, ok := p.GetEncryptedKey(); ok {
key, ok := auth.provisioners.LoadEncryptedKey(kid)
assert.True(t, ok)
assert.Equals(t, p.EncryptedKey, key)
assert.Equals(t, encryptedKey, key)
}
}
// sanity check
_, ok = auth.provisionerIDIndex.Load("fooo")
_, ok = auth.provisioners.Load("fooo")
assert.False(t, ok)
assert.Equals(t, auth.audiences, []string{
"step-certificate-authority",
"https://127.0.0.1/sign",
"https://127.0.0.1/1.0/sign",
})
}
}
})

@ -2,14 +2,13 @@ package authority
import (
"crypto/x509"
"encoding/asn1"
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/x509util"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
)
type idUsed struct {
@ -17,49 +16,21 @@ type idUsed struct {
Subject string `json:"sub,omitempty"`
}
// Claims extends jwt.Claims with step attributes.
// Claims extends jose.Claims with step attributes.
type Claims struct {
jwt.Claims
SANs []string `json:"sans,omitempty"`
}
// matchesAudience returns true if A and B share at least one element.
func matchesAudience(as, bs []string) bool {
if len(bs) == 0 || len(as) == 0 {
return false
}
for _, b := range bs {
for _, a := range as {
if b == a || stripPort(a) == stripPort(b) {
return true
}
}
}
return false
}
// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
func stripPort(rawurl string) string {
u, err := url.Parse(rawurl)
if err != nil {
return rawurl
}
u.Host = u.Hostname()
return u.String()
jose.Claims
SANs []string `json:"sans,omitempty"`
Email string `json:"email,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
// Authorize authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
func (a *Authority) Authorize(ott string) ([]interface{}, error) {
var (
errContext = map[string]interface{}{"ott": ott}
claims = Claims{}
)
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
var errContext = map[string]interface{}{"ott": ott}
// Validate payload
token, err := jwt.ParseSigned(ott)
token, err := jose.ParseSigned(ott)
if err != nil {
return nil, &apiError{errors.Wrapf(err, "authorize: error parsing token"),
http.StatusUnauthorized, errContext}
@ -68,86 +39,52 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
// Get claims w/out verification. We need to look up the provisioner
// key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner.
var claims Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, &apiError{err, http.StatusUnauthorized, errContext}
}
kid := token.Headers[0].KeyID // JWT will only have 1 header.
if len(kid) == 0 {
return nil, &apiError{errors.New("authorize: token KeyID cannot be empty"),
http.StatusUnauthorized, errContext}
}
pid := claims.Issuer + ":" + kid
val, ok := a.provisionerIDIndex.Load(pid)
if !ok {
return nil, &apiError{errors.Errorf("authorize: provisioner with id %s not found", pid),
http.StatusUnauthorized, errContext}
}
p, ok := val.(*Provisioner)
if !ok {
return nil, &apiError{errors.Errorf("authorize: invalid provisioner type"),
http.StatusInternalServerError, errContext}
}
if err = token.Claims(p.Key, &claims); err != nil {
return nil, &apiError{err, http.StatusUnauthorized, errContext}
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
if err = claims.ValidateWithLeeway(jwt.Expected{
Issuer: p.Name,
}, time.Minute); err != nil {
return nil, &apiError{errors.Wrapf(err, "authorize: invalid token"),
http.StatusUnauthorized, errContext}
}
// Do not accept tokens issued before the start of the ca.
// This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt > 0 && claims.IssuedAt.Time().Before(a.startTime) {
return nil, &apiError{errors.New("token issued before the bootstrap of certificate authority"),
return nil, &apiError{errors.New("authorize: token issued before the bootstrap of certificate authority"),
http.StatusUnauthorized, errContext}
}
}
if !matchesAudience(claims.Audience, a.audiences) {
return nil, &apiError{errors.New("authorize: token audience invalid"), http.StatusUnauthorized,
errContext}
}
if claims.Subject == "" {
return nil, &apiError{errors.New("authorize: token subject cannot be empty"),
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(token, &claims.Claims)
if !ok {
return nil, &apiError{
errors.Errorf("authorize: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")),
http.StatusUnauthorized, errContext}
}
// NOTE: This is for backwards compatibility with older versions of cli
// and certificates. Older versions added the token subject as the only SAN
// in a CSR by default.
if len(claims.SANs) == 0 {
claims.SANs = []string{claims.Subject}
}
dnsNames, ips := x509util.SplitSANs(claims.SANs)
if err != nil {
return nil, err
}
signOps := []interface{}{
&commonNameClaim{claims.Subject},
&dnsNamesClaim{dnsNames},
&ipAddressesClaim{ips},
p,
// Store the token to protect against reuse.
var reuseKey string
switch p.GetType() {
case provisioner.TypeJWK:
reuseKey = claims.ID
case provisioner.TypeOIDC:
reuseKey = claims.Nonce
}
if reuseKey != "" {
if _, ok := a.ottMap.LoadOrStore(reuseKey, &idUsed{
UsedAt: time.Now().Unix(),
Subject: claims.Subject,
}); ok {
return nil, &apiError{errors.Errorf("authorize: token already used"), http.StatusUnauthorized, errContext}
}
}
// Store the token to protect against reuse.
if _, ok := a.ottMap.LoadOrStore(claims.ID, &idUsed{
UsedAt: time.Now().Unix(),
Subject: claims.Subject,
}); ok {
return nil, &apiError{errors.Errorf("token already used"), http.StatusUnauthorized,
errContext}
// Call the provisioner Authorize method to get the signing options
opts, err := p.Authorize(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorize"), http.StatusUnauthorized, errContext}
}
return signOps, nil
return opts, nil
}
// authorizeRenewal tries to locate the step provisioner extension, and checks
@ -157,46 +94,20 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
// TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenewal(crt *x509.Certificate) error {
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
for _, e := range crt.Extensions {
if e.Id.Equal(stepOIDProvisioner) {
var provisioner stepProvisionerASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return &apiError{
err: errors.Wrap(err, "error decoding step provisioner extension"),
code: http.StatusInternalServerError,
context: errContext,
}
}
// Look for the provisioner, if it cannot be found, renewal will not
// be authorized.
pid := string(provisioner.Name) + ":" + string(provisioner.CredentialID)
val, ok := a.provisionerIDIndex.Load(pid)
if !ok {
return &apiError{
err: errors.Errorf("not found: provisioner %s", pid),
code: http.StatusUnauthorized,
context: errContext,
}
}
p, ok := val.(*Provisioner)
if !ok {
return &apiError{
err: errors.Errorf("invalid type: provisioner %s, type %T", pid, val),
code: http.StatusInternalServerError,
context: errContext,
}
}
if p.Claims.IsDisableRenewal() {
return &apiError{
err: errors.Errorf("renew disabled: provisioner %s", pid),
code: http.StatusUnauthorized,
context: errContext,
}
}
return nil
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok {
return &apiError{
err: errors.New("provisioner not found"),
code: http.StatusUnauthorized,
context: errContext,
}
}
if err := p.AuthorizeRenewal(crt); err != nil {
return &apiError{
err: err,
code: http.StatusUnauthorized,
context: errContext,
}
}
return nil
}

@ -7,100 +7,52 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/keys"
stepJOSE "github.com/smallstep/cli/jose"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
)
func TestMatchesAudience(t *testing.T) {
type matchesTest struct {
a, b []string
exp bool
func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
if err != nil {
return "", err
}
tests := map[string]matchesTest{
"false arg1 empty": {
a: []string{},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: false,
},
"false arg2 empty": {
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
b: []string{},
exp: false,
},
"false arg1,arg2 empty": {
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
b: []string{"step-gateway", "step-cli"},
exp: false,
},
"false": {
a: []string{"step-gateway", "step-cli"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: false,
},
"true": {
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: true,
},
"true,portsA": {
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: true,
},
"true,portsB": {
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
exp: true,
},
"true,portsAB": {
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
exp: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
})
}
}
func TestStripPort(t *testing.T) {
type args struct {
rawurl string
id, err := randutil.ASCII(64)
if err != nil {
return "", err
}
tests := []struct {
name string
args args
want string
claims := struct {
jose.Claims
SANS []string `json:"sans"`
}{
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := stripPort(tt.args.rawurl); got != tt.want {
t.Errorf("stripPort() = %v, want %v", got, tt.want)
}
})
Claims: jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
},
SANS: sans,
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func TestAuthorize(t *testing.T) {
a := testAuthority(t)
jwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_priv.jwk",
stepJOSE.WithPassword([]byte("pass")))
assert.FatalError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
// Invalid keys
keyNoKid := &jose.JSONWebKey{Key: key.Key, KeyID: ""}
keyBadKid := &jose.JSONWebKey{Key: key.Key, KeyID: "foo"}
now := time.Now()
validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/sign"}
@ -120,100 +72,37 @@ func TestAuthorize(t *testing.T) {
}
},
"fail empty key id": func(t *testing.T) *authorizeTest {
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT"))
assert.FatalError(t, err)
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyNoKid)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
err: &apiError{errors.New("authorize: token KeyID cannot be empty"),
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
http.StatusUnauthorized, context{"ott": raw}},
}
},
"fail provisioner not found": func(t *testing.T) *authorizeTest {
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
assert.FatalError(t, err)
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyBadKid)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
err: &apiError{errors.New("authorize: provisioner with id step-cli:foo not found"),
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
http.StatusUnauthorized, context{"ott": raw}},
}
},
"fail invalid provisioner": func(t *testing.T) *authorizeTest {
_a := testAuthority(t)
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
assert.FatalError(t, err)
_a.provisionerIDIndex.Store(validIssuer+":foo", "42")
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: _a,
ott: raw,
err: &apiError{errors.New("authorize: invalid provisioner type"),
http.StatusInternalServerError, context{"ott": raw}},
}
},
"fail invalid issuer": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "subject",
Issuer: "invalid-issuer",
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
raw, err := generateToken("test.smallstep.com", "invalid-issuer", validAudience[0], nil, now, key)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
err: &apiError{errors.New("authorize: provisioner with id invalid-issuer:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc not found"),
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
http.StatusUnauthorized, context{"ott": raw}},
}
},
"fail empty subject": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
raw, err := generateToken("", validIssuer, validAudience[0], nil, now, key)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
@ -223,64 +112,34 @@ func TestAuthorize(t *testing.T) {
}
},
"fail verify-sig-failure": func(t *testing.T) *authorizeTest {
_, priv2, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
invalidKeySig, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.ES256,
Key: priv2,
}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
}
raw, err := jwt.Signed(invalidKeySig).Claims(cl).CompactSerialize()
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
err: &apiError{errors.New("square/go-jose: error in cryptographic primitive"),
http.StatusUnauthorized, context{"ott": raw}},
ott: raw + "00",
err: &apiError{errors.New("authorize: error parsing claims: square/go-jose: error in cryptographic primitive"),
http.StatusUnauthorized, context{"ott": raw + "00"}},
}
},
"fail token-already-used": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "42",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
assert.FatalError(t, err)
_, err = a.Authorize(raw)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
err: &apiError{errors.New("token already used"),
err: &apiError{errors.New("authorize: token already used"),
http.StatusUnauthorized, context{"ott": raw}},
}
},
"ok": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
res: []interface{}{"1", "2", "3", "4"},
res: []interface{}{"1", "2", "3", "4", "5", "6"},
}
},
}

@ -1,117 +0,0 @@
package authority
import (
"net"
"reflect"
"time"
"github.com/pkg/errors"
x509 "github.com/smallstep/cli/pkg/x509"
)
// certClaim interface is implemented by types used to validate specific claims in a
// certificate request.
type certClaim interface {
Valid(crt *x509.Certificate) error
}
// ValidateClaims returns nil if all the claims are validated, it will return
// the first error if a claim fails.
func validateClaims(crt *x509.Certificate, claims []certClaim) (err error) {
for _, c := range claims {
if err = c.Valid(crt); err != nil {
return err
}
}
return
}
// commonNameClaim validates the common name of a certificate request.
type commonNameClaim struct {
name string
}
// Valid checks that certificate request common name matches the one configured.
func (c *commonNameClaim) Valid(crt *x509.Certificate) error {
if crt.Subject.CommonName == "" {
return errors.New("common name cannot be empty")
}
if crt.Subject.CommonName != c.name {
return errors.Errorf("common name claim failed - got %s, want %s", crt.Subject.CommonName, c.name)
}
return nil
}
type dnsNamesClaim struct {
names []string
}
// Valid checks that certificate request DNS Names match those configured in
// the bootstrap (token) flow.
func (c *dnsNamesClaim) Valid(crt *x509.Certificate) error {
tokMap := make(map[string]int)
for _, e := range c.names {
tokMap[e] = 1
}
crtMap := make(map[string]int)
for _, e := range crt.DNSNames {
crtMap[e] = 1
}
if !reflect.DeepEqual(tokMap, crtMap) {
return errors.Errorf("DNS names claim failed - got %s, want %s", crt.DNSNames, c.names)
}
return nil
}
type ipAddressesClaim struct {
ips []net.IP
}
// Valid checks that certificate request IP Addresses match those configured in
// the bootstrap (token) flow.
func (c *ipAddressesClaim) Valid(crt *x509.Certificate) error {
tokMap := make(map[string]int)
for _, e := range c.ips {
tokMap[e.String()] = 1
}
crtMap := make(map[string]int)
for _, e := range crt.IPAddresses {
crtMap[e.String()] = 1
}
if !reflect.DeepEqual(tokMap, crtMap) {
return errors.Errorf("IP Addresses claim failed - got %v, want %v", crt.IPAddresses, c.ips)
}
return nil
}
// certTemporalClaim validates the certificate temporal validity settings.
type certTemporalClaim struct {
min time.Duration
max time.Duration
}
// Validate validates the certificate temporal validity settings.
func (ctc *certTemporalClaim) Valid(crt *x509.Certificate) error {
var (
na = crt.NotAfter
nb = crt.NotBefore
d = na.Sub(nb)
now = time.Now()
)
if na.Before(now) {
return errors.Errorf("NotAfter: %v cannot be in the past", na)
}
if na.Before(nb) {
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
}
if d < ctc.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
d, ctc.min)
}
if d > ctc.max {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
d, ctc.max)
}
return nil
}

@ -1,132 +0,0 @@
package authority
import (
"crypto/x509/pkix"
"net"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
x509 "github.com/smallstep/cli/pkg/x509"
)
func TestCommonNameClaim_Valid(t *testing.T) {
tests := map[string]struct {
cnc certClaim
crt *x509.Certificate
err error
}{
"empty-common-name": {
cnc: &commonNameClaim{name: "foo"},
crt: &x509.Certificate{},
err: errors.New("common name cannot be empty"),
},
"wrong-common-name": {
cnc: &commonNameClaim{name: "foo"},
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}},
err: errors.New("common name claim failed - got bar, want foo"),
},
"ok": {
cnc: &commonNameClaim{name: "foo"},
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := tc.cnc.Valid(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestIPAddressesClaim_Valid(t *testing.T) {
tests := map[string]struct {
iac certClaim
crt *x509.Certificate
err error
}{
"unexpected-ip-in-crt": {
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
err: errors.New("IP Addresses claim failed - got [127.0.0.1 1.1.1.1], want [127.0.0.1]"),
},
"missing-ip-in-crt": {
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want [127.0.0.1 1.1.1.1]"),
},
"invalid-matcher-nonempty-ips": {
iac: &ipAddressesClaim{ips: []net.IP{}},
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want []"),
},
"ok": {
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
},
"ok-multiple-identical-ip-entries": {
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1")}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := tc.iac.Valid(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestDNSNamesClaim_Valid(t *testing.T) {
tests := map[string]struct {
dnc certClaim
crt *x509.Certificate
err error
}{
"unexpected-dns-name-in-crt": {
dnc: &dnsNamesClaim{names: []string{"foo"}},
crt: &x509.Certificate{DNSNames: []string{"foo", "bar"}},
err: errors.New("DNS names claim failed - got [foo bar], want [foo]"),
},
"ok": {
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
crt: &x509.Certificate{DNSNames: []string{"bar", "foo"}},
},
"missing-dns-name-in-crt": {
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
crt: &x509.Certificate{DNSNames: []string{"foo"}},
err: errors.New("DNS names claim failed - got [foo], want [foo bar]"),
},
"ok-multiple-identical-dns-entries": {
dnc: &dnsNamesClaim{names: []string{"foo"}},
crt: &x509.Certificate{DNSNames: []string{"foo", "foo", "foo"}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := tc.dnc.Valid(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

@ -2,11 +2,13 @@ package authority
import (
"encoding/json"
"fmt"
"net"
"os"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
)
@ -25,10 +27,10 @@ var (
Renegotiation: false,
}
defaultDisableRenewal = false
globalProvisionerClaims = ProvisionerClaims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
globalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{5 * time.Minute},
MaxTLSDur: &provisioner.Duration{24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}
)
@ -50,16 +52,15 @@ type Config struct {
// AuthConfig represents the configuration options for the authority.
type AuthConfig struct {
Provisioners []*Provisioner `json:"provisioners,omitempty"`
Template *x509util.ASN1DN `json:"template,omitempty"`
Claims *ProvisionerClaims `json:"claims,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
Provisioners provisioner.List `json:"provisioners"`
Template *x509util.ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
}
// Validate validates the authority configuration.
func (c *AuthConfig) Validate() error {
func (c *AuthConfig) Validate(audiences []string) error {
var err error
if c == nil {
return errors.New("authority cannot be undefined")
}
@ -70,11 +71,18 @@ func (c *AuthConfig) Validate() error {
if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil {
return err
}
// Initialize provisioners
config := provisioner.Config{
Claims: *c.Claims,
Audiences: audiences,
}
for _, p := range c.Provisioners {
if err := p.Init(c.Claims); err != nil {
if err := p.Init(config); err != nil {
return err
}
}
if c.Template == nil {
c.Template = &x509util.ASN1DN{}
}
@ -153,5 +161,16 @@ func (c *Config) Validate() error {
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
}
return c.AuthorityConfig.Validate()
return c.AuthorityConfig.Validate(c.getAudiences())
}
// getAudiences returns the legacy and possible urls without the ports that will
// be used as the default provisioner audiences. The CA might have proxies in
// front so we cannot rely on the port.
func (c *Config) getAudiences() []string {
audiences := []string{legacyAuthority}
for _, name := range c.DNSNames {
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
}
return audiences
}

@ -5,6 +5,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
stepJOSE "github.com/smallstep/cli/jose"
@ -17,13 +18,13 @@ func TestConfigValidate(t *testing.T) {
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
ac := &AuthConfig{
Provisioners: []*Provisioner{
{
Provisioners: provisioner.List{
&provisioner.JWK{
Name: "Max",
Type: "JWK",
Key: maxjwk,
},
{
&provisioner.JWK{
Name: "step-cli",
Type: "JWK",
Key: clijwk,
@ -229,13 +230,13 @@ func TestAuthConfigValidate(t *testing.T) {
assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
p := []*Provisioner{
{
p := provisioner.List{
&provisioner.JWK{
Name: "Max",
Type: "JWK",
Key: maxjwk,
},
{
&provisioner.JWK{
Name: "step-cli",
Type: "JWK",
Key: clijwk,
@ -263,9 +264,9 @@ func TestAuthConfigValidate(t *testing.T) {
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: []*Provisioner{
{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
{Name: "foo", Key: &jose.JSONWebKey{}},
Provisioners: provisioner.List{
&provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
&provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}},
},
},
err: errors.New("provisioner type cannot be empty"),
@ -293,7 +294,7 @@ func TestAuthConfigValidate(t *testing.T) {
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
err := tc.ac.Validate()
err := tc.ac.Validate([]string{})
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())

@ -1,17 +1,14 @@
package authority
package provisioner
import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/x509util"
jose "gopkg.in/square/go-jose.v2"
)
// ProvisionerClaims so that individual provisioners can override global claims.
type ProvisionerClaims struct {
globalClaims *ProvisionerClaims
// Claims so that individual provisioners can override global claims.
type Claims struct {
globalClaims *Claims
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
@ -19,19 +16,18 @@ type ProvisionerClaims struct {
}
// Init initializes and validates the individual provisioner claims.
func (pc *ProvisionerClaims) Init(global *ProvisionerClaims) (*ProvisionerClaims, error) {
func (pc *Claims) Init(global *Claims) (*Claims, error) {
if pc == nil {
pc = &ProvisionerClaims{}
pc = &Claims{}
}
pc.globalClaims = global
err := pc.Validate()
return pc, err
return pc, pc.Validate()
}
// DefaultTLSCertDuration returns the default TLS cert duration for the
// provisioner. If the default is not set within the provisioner, then the global
// default from the authority configuration will be used.
func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
func (pc *Claims) DefaultTLSCertDuration() time.Duration {
if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 {
return pc.globalClaims.DefaultTLSCertDuration()
}
@ -41,7 +37,7 @@ func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
// If the minimum is not set within the provisioner, then the global
// minimum from the authority configuration will be used.
func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
func (pc *Claims) MinTLSCertDuration() time.Duration {
if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 {
return pc.globalClaims.MinTLSCertDuration()
}
@ -51,7 +47,7 @@ func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
// If the maximum is not set within the provisioner, then the global
// maximum from the authority configuration will be used.
func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
func (pc *Claims) MaxTLSCertDuration() time.Duration {
if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 {
return pc.globalClaims.MaxTLSCertDuration()
}
@ -61,7 +57,7 @@ func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
// IsDisableRenewal returns if the renewal flow is disabled for the
// provisioner. If the property is not set within the provisioner, then the
// global value from the authority configuration will be used.
func (pc *ProvisionerClaims) IsDisableRenewal() bool {
func (pc *Claims) IsDisableRenewal() bool {
if pc.DisableRenewal == nil {
return pc.globalClaims.IsDisableRenewal()
}
@ -69,7 +65,7 @@ func (pc *ProvisionerClaims) IsDisableRenewal() bool {
}
// Validate validates and modifies the Claims with default values.
func (pc *ProvisionerClaims) Validate() error {
func (pc *Claims) Validate() error {
var (
min = pc.MinTLSCertDuration()
max = pc.MaxTLSCertDuration()
@ -93,52 +89,3 @@ func (pc *ProvisionerClaims) Validate() error {
return nil
}
}
// Provisioner - authorized entity that can sign tokens necessary for signature requests.
type Provisioner struct {
Name string `json:"name,omitempty"`
Type string `json:"type,omitempty"`
Key *jose.JSONWebKey `json:"key,omitempty"`
EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *ProvisionerClaims `json:"claims,omitempty"`
}
// Init initializes and validates a the fields of Provisioner type.
func (p *Provisioner) Init(global *ProvisionerClaims) error {
switch {
case p.Name == "":
return errors.New("provisioner name cannot be empty")
case p.Type == "":
return errors.New("provisioner type cannot be empty")
case p.Key == nil:
return errors.New("provisioner key cannot be empty")
}
var err error
p.Claims, err = p.Claims.Init(global)
return err
}
// getTLSApps returns a list of modifiers and validators that will be applied to
// the certificate.
func (p *Provisioner) getTLSApps(so SignOptions) ([]x509util.WithOption, []certClaim, error) {
c := p.Claims
return []x509util.WithOption{
x509util.WithNotBeforeAfterDuration(so.NotBefore,
so.NotAfter, c.DefaultTLSCertDuration()),
withProvisionerOID(p.Name, p.Key.KeyID),
}, []certClaim{
&certTemporalClaim{
min: c.MinTLSCertDuration(),
max: c.MaxTLSCertDuration(),
},
}, nil
}
// ID returns the provisioner identifier. The name and credential id should
// uniquely identify any provisioner.
func (p *Provisioner) ID() string {
return p.Name + ":" + p.Key.KeyID
}

@ -0,0 +1,212 @@
package provisioner
import (
"crypto/sha1"
"crypto/x509"
"encoding/asn1"
"encoding/binary"
"encoding/hex"
"fmt"
"net/url"
"sort"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100
type uidProvisioner struct {
provisioner Interface
uid string
}
type provisionerSlice []uidProvisioner
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Collection is a memory map of provisioners.
type Collection struct {
byID *sync.Map
byKey *sync.Map
sorted provisionerSlice
audiences []string
}
// NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner.
func NewCollection(audiences []string) *Collection {
return &Collection{
byID: new(sync.Map),
byKey: new(sync.Map),
audiences: audiences,
}
}
// Load a provisioner by the ID.
func (c *Collection) Load(id string) (Interface, bool) {
return loadProvisioner(c.byID, id)
}
// LoadByToken parses the token claims and loads the provisioner associated.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
// match with server audiences
if matchesAudience(claims.Audience, c.audiences) {
// If matches with stored audiences it will be a JWT token (default), and
// the id would be <issuer>:<kid>.
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
}
// The ID will be just the clientID stored in azp or aud.
var payload openIDPayload
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
return nil, false
}
// audience is required
if len(payload.Audience) == 0 {
return nil, false
}
if len(payload.AuthorizedParty) > 0 {
return c.Load(payload.AuthorizedParty)
}
return c.Load(payload.Audience[0])
}
// LoadByCertificate looks for the provisioner extension and extracts the
// proper id to load the provisioner.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
for _, e := range cert.Extensions {
if e.Id.Equal(stepOIDProvisioner) {
var provisioner stepProvisionerASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}
if provisioner.Type == int(TypeJWK) {
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
}
return c.Load(string(provisioner.CredentialID))
}
}
// Default to noop provisioner if an extension is not found. This allows to
// accept a renewal of a cert without the provisioner extension.
return &noop{}, true
}
// LoadEncryptedKey returns an encrypted key by indexed by KeyID. At this moment
// only JWK encrypted keys are indexed by KeyID.
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
p, ok := loadProvisioner(c.byKey, keyID)
if !ok {
return "", false
}
_, key, ok := p.GetEncryptedKey()
return key, ok
}
// Store adds a provisioner to the collection and enforces the uniqueness of
// provisioner IDs.
func (c *Collection) Store(p Interface) error {
// Store provisioner always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
return errors.New("cannot add multiple provisioners with the same id")
}
// Store provisioner in byKey if EncryptedKey is defined.
if kid, _, ok := p.GetEncryptedKey(); ok {
c.byKey.Store(kid, p)
}
// Store sorted provisioners.
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
bi := make([]byte, 4)
sum := provisionerSum(p)
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
c.sorted = append(c.sorted, uidProvisioner{
provisioner: p,
uid: hex.EncodeToString(sum),
})
sort.Sort(c.sorted)
return nil
}
// Find implements pagination on a list of sorted provisioners.
func (c *Collection) Find(cursor string, limit int) (List, string) {
switch {
case limit <= 0:
limit = DefaultProvisionersLimit
case limit > DefaultProvisionersMax:
limit = DefaultProvisionersMax
}
n := c.sorted.Len()
cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
slice := List{}
for ; i < n && len(slice) < limit; i++ {
slice = append(slice, c.sorted[i].provisioner)
}
if i < n {
return slice, strings.TrimLeft(c.sorted[i].uid, "0")
}
return slice, ""
}
func loadProvisioner(m *sync.Map, key string) (Interface, bool) {
i, ok := m.Load(key)
if !ok {
return nil, false
}
p, ok := i.(Interface)
if !ok {
return nil, false
}
return p, true
}
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
// create the unique and sorted id.
func provisionerSum(p Interface) []byte {
sum := sha1.Sum([]byte(p.GetID()))
return sum[:]
}
// matchesAudience returns true if A and B share at least one element.
func matchesAudience(as, bs []string) bool {
if len(bs) == 0 || len(as) == 0 {
return false
}
for _, b := range bs {
for _, a := range as {
if b == a || stripPort(a) == stripPort(b) {
return true
}
}
}
return false
}
// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
func stripPort(rawurl string) string {
u, err := url.Parse(rawurl)
if err != nil {
return rawurl
}
u.Host = u.Hostname()
return u.String()
}

@ -0,0 +1,390 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"reflect"
"strings"
"sync"
"testing"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
func TestCollection_Load(t *testing.T) {
p, err := generateJWK()
assert.FatalError(t, err)
byID := new(sync.Map)
byID.Store(p.GetID(), p)
byID.Store("string", "a-string")
type fields struct {
byID *sync.Map
}
type args struct {
id string
}
tests := []struct {
name string
fields fields
args args
want Interface
want1 bool
}{
{"ok", fields{byID}, args{p.GetID()}, p, true},
{"fail", fields{byID}, args{"fail"}, nil, false},
{"invalid", fields{byID}, args{"string"}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Collection{
byID: tt.fields.byID,
}
got, got1 := c.Load(tt.args.id)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_LoadByToken(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
byID := new(sync.Map)
byID.Store(p1.GetID(), p1)
byID.Store(p2.GetID(), p2)
byID.Store(p3.GetID(), p3)
byID.Store("string", "a-string")
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk)
assert.FatalError(t, err)
t1, c1, err := parseToken(token)
assert.FatalError(t, err)
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
assert.FatalError(t, err)
token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk)
assert.FatalError(t, err)
t2, c2, err := parseToken(token)
assert.FatalError(t, err)
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
t3, c3, err := parseToken(token)
assert.FatalError(t, err)
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
t4, c4, err := parseToken(token)
assert.FatalError(t, err)
type fields struct {
byID *sync.Map
audiences []string
}
type args struct {
token *jose.JSONWebToken
claims *jose.Claims
}
tests := []struct {
name string
fields fields
args args
want Interface
want1 bool
}{
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Collection{
byID: tt.fields.byID,
audiences: tt.fields.audiences,
}
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_LoadByCertificate(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
byID := new(sync.Map)
byID.Store(p1.GetID(), p1)
byID.Store(p2.GetID(), p2)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err)
ok1Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok1Ext},
}
ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext},
}
notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt},
}
badCert := &x509.Certificate{
Extensions: []pkix.Extension{
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
},
}
type fields struct {
byID *sync.Map
audiences []string
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
want Interface
want1 bool
}{
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Collection{
byID: tt.fields.byID,
audiences: tt.fields.audiences,
}
got, got1 := c.LoadByCertificate(tt.args.cert)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_LoadEncryptedKey(t *testing.T) {
c := NewCollection(testAudiences)
p1, err := generateJWK()
assert.FatalError(t, err)
assert.FatalError(t, c.Store(p1))
p2, err := generateOIDC()
assert.FatalError(t, err)
assert.FatalError(t, c.Store(p2))
// Add oidc in byKey.
// It should not happen.
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
c.byKey.Store(p2KeyID, p2)
type args struct {
keyID string
}
tests := []struct {
name string
args args
want string
want1 bool
}{
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
{"oidc", args{p2KeyID}, "", false},
{"notFound", args{"not-found"}, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
if got != tt.want {
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_Store(t *testing.T) {
c := NewCollection(testAudiences)
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
type args struct {
p Interface
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok1", args{p1}, false},
{"ok2", args{p2}, false},
{"fail1", args{p1}, true},
{"fail2", args{p2}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCollection_Find(t *testing.T) {
c, err := generateCollection(10, 10)
assert.FatalError(t, err)
trim := func(s string) string {
return strings.TrimLeft(s, "0")
}
toList := func(ps provisionerSlice) List {
l := List{}
for _, p := range ps {
l = append(l, p.provisioner)
}
return l
}
type args struct {
cursor string
limit int
}
tests := []struct {
name string
args args
want List
want1 string
}{
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func Test_matchesAudience(t *testing.T) {
type matchesTest struct {
a, b []string
exp bool
}
tests := map[string]matchesTest{
"false arg1 empty": {
a: []string{},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: false,
},
"false arg2 empty": {
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
b: []string{},
exp: false,
},
"false arg1,arg2 empty": {
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
b: []string{"step-gateway", "step-cli"},
exp: false,
},
"false": {
a: []string{"step-gateway", "step-cli"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: false,
},
"true": {
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: true,
},
"true,portsA": {
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
exp: true,
},
"true,portsB": {
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
exp: true,
},
"true,portsAB": {
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
exp: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
})
}
}
func Test_stripPort(t *testing.T) {
type args struct {
rawurl string
}
tests := []struct {
name string
args args
want string
}{
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := stripPort(tt.args.rawurl); got != tt.want {
t.Errorf("stripPort() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,45 @@
package provisioner
import (
"encoding/json"
"time"
"github.com/pkg/errors"
)
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
type Duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.Duration.String())
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
var (
s string
_d time.Duration
)
if d == nil {
return errors.New("duration cannot be nil")
}
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshaling %s", data)
}
if _d, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
d.Duration = _d
return
}

@ -0,0 +1,61 @@
package provisioner
import (
"reflect"
"testing"
"time"
)
func TestDuration_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
d *Duration
args args
want *Duration
wantErr bool
}{
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
{"nil", nil, args{nil}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(tt.d, tt.want) {
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
}
})
}
}
func TestDuration_MarshalJSON(t *testing.T) {
tests := []struct {
name string
d *Duration
want []byte
wantErr bool
}{
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,125 @@
package provisioner
import (
"crypto/x509"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose"
)
// jwtPayload extends jwt.Claims with step attributes.
type jwtPayload struct {
jose.Claims
SANs []string `json:"sans,omitempty"`
}
// JWK is the default provisioner, an entity that can sign tokens necessary for
// signature requests.
type JWK struct {
Type string `json:"type"`
Name string `json:"name"`
Key *jose.JSONWebKey `json:"key"`
EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"`
audiences []string
}
// GetID returns the provisioner unique identifier. The name and credential id
// should uniquely identify any JWK provisioner.
func (p *JWK) GetID() string {
return p.Name + ":" + p.Key.KeyID
}
// GetName returns the name of the provisioner.
func (p *JWK) GetName() string {
return p.Name
}
// GetType returns the type of provisioner.
func (p *JWK) GetType() Type {
return TypeJWK
}
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
func (p *JWK) GetEncryptedKey() (string, string, bool) {
return p.Key.KeyID, p.EncryptedKey, len(p.EncryptedKey) > 0
}
// Init initializes and validates the fields of a JWK type.
func (p *JWK) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
case p.Name == "":
return errors.New("provisioner name cannot be empty")
case p.Key == nil:
return errors.New("provisioner key cannot be empty")
}
p.Claims, err = p.Claims.Init(&config.Claims)
p.audiences = config.Audiences
return err
}
// Authorize validates the given token.
func (p *JWK) Authorize(token string) ([]SignOption, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
var claims jwtPayload
if err = jwt.Claims(p.Key, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
if err = claims.ValidateWithLeeway(jose.Expected{
Issuer: p.Name,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
}
if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty")
}
// NOTE: This is for backwards compatibility with older versions of cli
// and certificates. Older versions added the token subject as the only SAN
// in a CSR by default.
if len(claims.SANs) == 0 {
claims.SANs = []string{claims.Subject}
}
dnsNames, ips := x509util.SplitSANs(claims.SANs)
return []SignOption{
commonNameValidator(claims.Subject),
dnsNamesValidator(dnsNames),
ipAddressesValidator(ips),
profileDefaultDuration(p.Claims.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()),
}, nil
}
// AuthorizeRenewal returns an error if the renewal is disabled.
func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
if p.Claims.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
}
return nil
}
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(token string) error {
return errors.New("not implemented")
}

@ -0,0 +1,256 @@
package provisioner
import (
"crypto/x509"
"errors"
"strings"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
var (
defaultDisableRenewal = false
globalProvisionerClaims = Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}
)
func TestJWK_Getters(t *testing.T) {
p, err := generateJWK()
assert.FatalError(t, err)
if got := p.GetID(); got != p.Name+":"+p.Key.KeyID {
t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID)
}
if got := p.GetName(); got != p.Name {
t.Errorf("JWK.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeJWK {
t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK)
}
kid, key, ok := p.GetEncryptedKey()
if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false {
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, p.Key.KeyID, p.EncryptedKey, true)
}
p.EncryptedKey = ""
kid, key, ok = p.GetEncryptedKey()
if kid != p.Key.KeyID || key != "" || ok == true {
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, p.Key.KeyID, "", false)
}
}
func TestJWK_Init(t *testing.T) {
type ProvisionerValidateTest struct {
p *JWK
err error
}
tests := map[string]func(*testing.T) ProvisionerValidateTest{
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{
Type: "JWK",
},
err: errors.New("provisioner name cannot be empty"),
}
},
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{Name: "foo"},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar"},
err: errors.New("provisioner key cannot be empty"),
}
},
"ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
}
},
}
config := Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
err := tc.p.Init(config)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestJWK_Authorize(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
key1, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
key2, err := decryptJSONWebKey(p2.EncryptedKey)
assert.FatalError(t, err)
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
assert.FatalError(t, err)
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2)
assert.FatalError(t, err)
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1)
assert.FatalError(t, err)
// Invalid tokens
parts := strings.Split(t1, ".")
key3, err := generateJSONWebKey()
assert.FatalError(t, err)
// missing key
failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3)
assert.FatalError(t, err)
// invalid token
failTok := "foo." + parts[1] + "." + parts[2]
// invalid claims
failClaims := parts[0] + ".foo." + parts[1]
// invalid issuer
failIss, err := generateSimpleToken("foobar", testAudiences[0], key1)
assert.FatalError(t, err)
// invalid audience
failAud, err := generateSimpleToken(p1.Name, "foobar", key1)
assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
// no subject
failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1)
assert.FatalError(t, err)
// expired
failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
assert.FatalError(t, err)
// not before
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
assert.FatalError(t, err)
// Remove encrypted key for p2
p2.EncryptedKey = ""
type args struct {
token string
}
tests := []struct {
name string
prov *JWK
args args
wantErr bool
}{
{"ok", p1, args{t1}, false},
{"ok-no-encrypted-key", p2, args{t2}, false},
{"ok-no-sans", p1, args{t3}, false},
{"fail-key", p1, args{failKey}, true},
{"fail-token", p1, args{failTok}, true},
{"fail-claims", p1, args{failClaims}, true},
{"fail-issuer", p1, args{failIss}, true},
{"fail-audience", p1, args{failAud}, true},
{"fail-signature", p1, args{failSig}, true},
{"fail-subject", p1, args{failSub}, true},
{"fail-expired", p1, args{failExp}, true},
{"fail-not-before", p1, args{failNbf}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.Authorize(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else {
assert.NotNil(t, got)
assert.Len(t, 6, got)
}
})
}
}
func TestJWK_AuthorizeRenewal(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{
globalClaims: &globalProvisionerClaims,
DisableRenewal: &disable,
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *JWK
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestJWK_AuthorizeRevoke(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
key1, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *JWK
args args
wantErr bool
}{
{"disabled", p1, args{t1}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

@ -0,0 +1,135 @@
package provisioner
import (
"encoding/json"
"math/rand"
"net/http"
"regexp"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
const (
defaultCacheAge = 12 * time.Hour
defaultCacheJitter = 1 * time.Hour
)
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
type keyStore struct {
sync.RWMutex
uri string
keySet jose.JSONWebKeySet
timer *time.Timer
expiry time.Time
jitter time.Duration
}
func newKeyStore(uri string) (*keyStore, error) {
keys, age, err := getKeysFromJWKsURI(uri)
if err != nil {
return nil, err
}
ks := &keyStore{
uri: uri,
keySet: keys,
expiry: getExpirationTime(age),
jitter: getCacheJitter(age),
}
next := ks.nextReloadDuration(age)
ks.timer = time.AfterFunc(next, ks.reload)
return ks, nil
}
func (ks *keyStore) Close() {
ks.timer.Stop()
}
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
ks.RLock()
// Force reload if expiration has passed
if time.Now().After(ks.expiry) {
ks.RUnlock()
ks.reload()
ks.RLock()
}
keys = ks.keySet.Key(kid)
ks.RUnlock()
return
}
func (ks *keyStore) reload() {
var next time.Duration
keys, age, err := getKeysFromJWKsURI(ks.uri)
if err != nil {
next = ks.nextReloadDuration(ks.jitter / 2)
} else {
ks.Lock()
ks.keySet = keys
ks.expiry = getExpirationTime(age)
ks.jitter = getCacheJitter(age)
next = ks.nextReloadDuration(age)
ks.Unlock()
}
ks.Lock()
ks.timer.Reset(next)
ks.Unlock()
}
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
n := rand.Int63n(int64(ks.jitter))
age -= time.Duration(n)
if age < 0 {
age = 0
}
return age
}
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
var keys jose.JSONWebKeySet
resp, err := http.Get(uri)
if err != nil {
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
}
defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
}
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
}
func getCacheAge(cacheControl string) time.Duration {
age := defaultCacheAge
if len(cacheControl) > 0 {
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
if len(match) > 0 {
if len(match[0]) == 2 {
maxAge := match[0][1]
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
if err != nil {
return defaultCacheAge
}
age = time.Duration(maxAgeInt) * time.Second
}
}
}
return age
}
func getCacheJitter(age time.Duration) time.Duration {
switch {
case age > time.Hour:
return defaultCacheJitter
default:
return age / 3
}
}
func getExpirationTime(age time.Duration) time.Time {
return time.Now().Truncate(time.Second).Add(age)
}

@ -0,0 +1,121 @@
package provisioner
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
func Test_newKeyStore(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL)
assert.FatalError(t, err)
defer ks.Close()
type args struct {
uri string
}
tests := []struct {
name string
args args
want jose.JSONWebKeySet
wantErr bool
}{
{"ok", args{srv.URL}, ks.keySet, false},
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newKeyStore(tt.args.uri)
if (err != nil) != tt.wantErr {
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
if !reflect.DeepEqual(got.keySet, tt.want) {
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
}
got.Close()
}
})
}
}
func Test_keyStore(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL + "/random")
assert.FatalError(t, err)
defer ks.Close()
ks.RLock()
keySet1 := ks.keySet
ks.RUnlock()
// Check contents
assert.Len(t, 2, keySet1.Keys)
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
assert.Len(t, 0, ks.Get("foobar"))
// Wait for rotation
time.Sleep(5 * time.Second)
ks.RLock()
keySet2 := ks.keySet
ks.RUnlock()
if reflect.DeepEqual(keySet1, keySet2) {
t.Error("keyStore did not rotated")
}
// Check contents
assert.Len(t, 2, keySet2.Keys)
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
assert.Len(t, 0, ks.Get("foobar"))
// Check hits
resp, err := srv.Client().Get(srv.URL + "/hits")
assert.FatalError(t, err)
hits := struct {
Hits int `json:"hits"`
}{}
defer resp.Body.Close()
err = json.NewDecoder(resp.Body).Decode(&hits)
assert.FatalError(t, err)
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}
func Test_keyStore_Get(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL)
assert.FatalError(t, err)
defer ks.Close()
type args struct {
kid string
}
tests := []struct {
name string
ks *keyStore
args args
wantKeys []jose.JSONWebKey
}{
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
}
})
}
}

@ -0,0 +1,37 @@
package provisioner
import "crypto/x509"
// noop provisioners is a provisioner that accepts anything.
type noop struct{}
func (p *noop) GetID() string {
return "noop"
}
func (p *noop) GetName() string {
return "noop"
}
func (p *noop) GetType() Type {
return noopType
}
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
return "", "", false
}
func (p *noop) Init(config Config) error {
return nil
}
func (p *noop) Authorize(token string) ([]SignOption, error) {
return []SignOption{}, nil
}
func (p *noop) AuthorizeRenewal(cert *x509.Certificate) error {
return nil
}
func (p *noop) AuthorizeRevoke(token string) error {
return nil
}

@ -0,0 +1,27 @@
package provisioner
import (
"crypto/x509"
"testing"
"github.com/smallstep/assert"
)
func Test_noop(t *testing.T) {
p := noop{}
assert.Equals(t, "noop", p.GetID())
assert.Equals(t, "noop", p.GetName())
assert.Equals(t, noopType, p.GetType())
assert.Equals(t, nil, p.Init(Config{}))
assert.Equals(t, nil, p.AuthorizeRenewal(&x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke("foo"))
kid, key, ok := p.GetEncryptedKey()
assert.Equals(t, "", kid)
assert.Equals(t, "", key)
assert.Equals(t, false, ok)
sigOptions, err := p.Authorize("foo")
assert.Equals(t, []SignOption{}, sigOptions)
assert.Equals(t, nil, err)
}

@ -0,0 +1,243 @@
package provisioner
import (
"crypto/x509"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
// openIDConfiguration contains the necessary properties in the
// `/.well-known/openid-configuration` document.
type openIDConfiguration struct {
Issuer string `json:"issuer"`
JWKSetURI string `json:"jwks_uri"`
}
// Validate validates the values in a well-known OpenID configuration endpoint.
func (c openIDConfiguration) Validate() error {
switch {
case c.Issuer == "":
return errors.New("issuer cannot be empty")
case c.JWKSetURI == "":
return errors.New("jwks_uri cannot be empty")
default:
return nil
}
}
// openIDPayload represents the fields on the id_token JWT payload.
type openIDPayload struct {
jose.Claims
AtHash string `json:"at_hash"`
AuthorizedParty string `json:"azp"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Hd string `json:"hd"`
Nonce string `json:"nonce"`
}
// OIDC represents an OAuth 2.0 OpenID Connect provider.
//
// ClientSecret is mandatory, but it can be an empty string.
type OIDC struct {
Type string `json:"type"`
Name string `json:"name"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
ConfigurationEndpoint string `json:"configurationEndpoint"`
Admins []string `json:"admins,omitempty"`
Domains []string `json:"domains,omitempty"`
Claims *Claims `json:"claims,omitempty"`
configuration openIDConfiguration
keyStore *keyStore
}
// IsAdmin returns true if the given email is in the Admins whitelist, false
// otherwise.
func (o *OIDC) IsAdmin(email string) bool {
email = sanitizeEmail(email)
for _, e := range o.Admins {
if email == sanitizeEmail(e) {
return true
}
}
return false
}
func sanitizeEmail(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i] + strings.ToLower(email[i:])
}
return email
}
// GetID returns the provisioner unique identifier, the OIDC provisioner the
// uses the clientID for this.
func (o *OIDC) GetID() string {
return o.ClientID
}
// GetName returns the name of the provisioner.
func (o *OIDC) GetName() string {
return o.Name
}
// GetType returns the type of provisioner.
func (o *OIDC) GetType() Type {
return TypeOIDC
}
// GetEncryptedKey is not available in an OIDC provisioner.
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
return "", "", false
}
// Init validates and initializes the OIDC provider.
func (o *OIDC) Init(config Config) (err error) {
switch {
case o.Type == "":
return errors.New("type cannot be empty")
case o.Name == "":
return errors.New("name cannot be empty")
case o.ClientID == "":
return errors.New("clientID cannot be empty")
case o.ConfigurationEndpoint == "":
return errors.New("configurationEndpoint cannot be empty")
}
// Update claims with global ones
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
return err
}
if err := o.configuration.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
}
// Get JWK key set
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
if err != nil {
return err
}
return nil
}
// ValidatePayload validates the given token payload.
func (o *OIDC) ValidatePayload(p openIDPayload) error {
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
// than a few minutes.
if err := p.ValidateWithLeeway(jose.Expected{
Issuer: o.configuration.Issuer,
Audience: jose.Audience{o.ClientID},
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return errors.Wrap(err, "failed to validate payload")
}
// Validate azp if present
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errors.New("failed to validate payload: invalid azp")
}
// Enforce an email claim
if p.Email == "" {
return errors.New("failed to validate payload: email not found")
}
// Validate domains (case-insensitive)
if !o.IsAdmin(p.Email) && len(o.Domains) > 0 {
email := sanitizeEmail(p.Email)
var found bool
for _, d := range o.Domains {
if strings.HasSuffix(email, "@"+strings.ToLower(d)) {
found = true
break
}
}
if !found {
return errors.New("failed to validate payload: email is not allowed")
}
}
return nil
}
// Authorize validates the given token.
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
// Parse claims to get the kid
var claims openIDPayload
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
}
found := false
kid := jwt.Headers[0].KeyID
keys := o.keyStore.Get(kid)
for _, key := range keys {
if err := jwt.Claims(key, &claims); err == nil {
found = true
break
}
}
if !found {
return nil, errors.New("cannot validate token")
}
if err := o.ValidatePayload(claims); err != nil {
return nil, err
}
// Admins should be able to authorize any SAN
if o.IsAdmin(claims.Email) {
return []SignOption{
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
}, nil
}
return []SignOption{
emailOnlyIdentity(claims.Email),
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
}, nil
}
// AuthorizeRenewal returns an error if the renewal is disabled.
func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
if o.Claims.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
}
return nil
}
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (o *OIDC) AuthorizeRevoke(token string) error {
return errors.New("not implemented")
}
func getAndDecode(uri string, v interface{}) error {
resp, err := http.Get(uri)
if err != nil {
return errors.Wrapf(err, "failed to connect to %s", uri)
}
defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
return errors.Wrapf(err, "error reading %s", uri)
}
return nil
}

@ -0,0 +1,327 @@
package provisioner
import (
"crypto/x509"
"fmt"
"strings"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
func Test_openIDConfiguration_Validate(t *testing.T) {
type fields struct {
Issuer string
JWKSetURI string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{"the-issuer", "the-jwks-uri"}, false},
{"no-issuer", fields{"", "the-jwks-uri"}, true},
{"no-jwks-uri", fields{"the-issuer", ""}, true},
{"empty", fields{"", ""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := openIDConfiguration{
Issuer: tt.fields.Issuer,
JWKSetURI: tt.fields.JWKSetURI,
}
if err := c.Validate(); (err != nil) != tt.wantErr {
t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestOIDC_Getters(t *testing.T) {
p, err := generateOIDC()
assert.FatalError(t, err)
if got := p.GetID(); got != p.ClientID {
t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID)
}
if got := p.GetName(); got != p.Name {
t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeOIDC {
t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func TestOIDC_Init(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
config := Config{
Claims: globalProvisionerClaims,
}
type fields struct {
Type string
Name string
ClientID string
ClientSecret string
ConfigurationEndpoint string
Claims *Claims
Admins []string
Domains []string
}
type args struct {
config Config
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false},
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false},
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &OIDC{
Type: tt.fields.Type,
Name: tt.fields.Name,
ClientID: tt.fields.ClientID,
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
Claims: tt.fields.Claims,
Admins: tt.fields.Admins,
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr == false {
assert.Len(t, 2, p.keyStore.keySet.Keys)
assert.Equals(t, openIDConfiguration{
Issuer: "the-issuer",
JWKSetURI: srv.URL + "/jwks_uri",
}, p.configuration)
}
})
}
}
func TestOIDC_Authorize(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p2.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1])
assert.FatalError(t, err)
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
failDomain, err := generateToken("subject", "the-issuer", p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid tokens
parts := strings.Split(t1, ".")
key, err := generateJSONWebKey()
assert.FatalError(t, err)
// missing key
failKey, err := generateSimpleToken("the-issuer", p1.ClientID, key)
assert.FatalError(t, err)
// invalid token
failTok := "foo." + parts[1] + "." + parts[2]
// invalid claims
failClaims := parts[0] + ".foo." + parts[1]
// invalid issuer
failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// invalid audience
failAud, err := generateSimpleToken("the-issuer", "foobar", &keys.Keys[0])
assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
// expired
failExp, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0])
assert.FatalError(t, err)
// not before
failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok1", p1, args{t1}, false},
{"ok2", p2, args{t2}, false},
{"admin", p3, args{t3}, false},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true},
{"fail-domain", p3, args{failDomain}, true},
{"fail-key", p1, args{failKey}, true},
{"fail-token", p1, args{failTok}, true},
{"fail-claims", p1, args{failClaims}, true},
{"fail-issuer", p1, args{failIss}, true},
{"fail-audience", p1, args{failAud}, true},
{"fail-signature", p1, args{failSig}, true},
{"fail-expired", p1, args{failExp}, true},
{"fail-not-before", p1, args{failNbf}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.Authorize(tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else {
assert.NotNil(t, got)
if tt.name == "admin" {
assert.Len(t, 3, got)
} else {
assert.Len(t, 4, got)
}
}
})
}
}
func TestOIDC_AuthorizeRenewal(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{
globalClaims: &globalProvisionerClaims,
DisableRenewal: &disable,
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"disabled", p1, args{t1}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_sanitizeEmail(t *testing.T) {
tests := []struct {
name string
email string
want string
}{
{"equal", "name@smallstep.com", "name@smallstep.com"},
{"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"},
{"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"},
{"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := sanitizeEmail(tt.email); got != tt.want {
t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,82 @@
package provisioner
import (
"crypto/x509"
"encoding/json"
"strings"
"github.com/pkg/errors"
)
// Interface is the interface that all provisioner types must implement.
type Interface interface {
GetID() string
GetName() string
GetType() Type
GetEncryptedKey() (kid string, key string, ok bool)
Init(config Config) error
Authorize(token string) ([]SignOption, error)
AuthorizeRenewal(cert *x509.Certificate) error
AuthorizeRevoke(token string) error
}
// Type indicates the provisioner Type.
type Type int
const (
noopType Type = 0
// TypeJWK is used to indicate the JWK provisioners.
TypeJWK Type = 1
// TypeOIDC is used to indicate the OIDC provisioners.
TypeOIDC Type = 2
)
// Config defines the default parameters used in the initialization of
// provisioners.
type Config struct {
// Claims are the default claims.
Claims Claims
// Audiences are the audiences used in the default provisioner, (JWK).
Audiences []string
}
type provisioner struct {
Type string `json:"type"`
}
// List represents a list of provisioners.
type List []Interface
// UnmarshalJSON implements json.Unmarshaler and allows to unmarshal a list of a
// interfaces into the right type.
func (l *List) UnmarshalJSON(data []byte) error {
ps := []json.RawMessage{}
if err := json.Unmarshal(data, &ps); err != nil {
return errors.Wrap(err, "error unmarshaling provisioner list")
}
*l = List{}
for _, data := range ps {
var typ provisioner
if err := json.Unmarshal(data, &typ); err != nil {
return errors.Errorf("error unmarshaling provisioner")
}
var p Interface
switch strings.ToLower(typ.Type) {
case "jwk":
p = &JWK{}
case "oidc":
p = &OIDC{}
default:
return errors.Errorf("provisioner type %s not supported", typ.Type)
}
if err := json.Unmarshal(data, p); err != nil {
return errors.Errorf("error unmarshaling provisioner")
}
*l = append(*l, p)
}
return nil
}

@ -0,0 +1,233 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"net"
"reflect"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/x509util"
)
// Options contains the options that can be passed to the Sign method.
type Options struct {
NotAfter time.Time `json:"notAfter"`
NotBefore time.Time `json:"notBefore"`
}
// SignOption is the interface used to collect all extra options used in the
// Sign method.
type SignOption interface{}
// CertificateValidator is the interface used to validate a X.509 certificate.
type CertificateValidator interface {
SignOption
Valid(crt *x509.Certificate) error
}
// CertificateRequestValidator is the interface used to validate a X.509
// certificate request.
type CertificateRequestValidator interface {
SignOption
Valid(req *x509.CertificateRequest) error
}
// ProfileModifier is the interface used to add custom options to the profile
// constructor. The options are used to modify the final certificate.
type ProfileModifier interface {
SignOption
Option(o Options) x509util.WithOption
}
// profileWithOption is a wrapper against x509util.WithOption to conform the
// interface.
type profileWithOption x509util.WithOption
func (v profileWithOption) Option(Options) x509util.WithOption {
return x509util.WithOption(v)
}
// profileDefaultDuration is a wrapper against x509util.WithOption to conform the
// interface.
type profileDefaultDuration time.Duration
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
return x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, time.Duration(v))
}
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
// SAN provided is the given email address.
type emailOnlyIdentity string
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
switch {
case len(req.DNSNames) > 0:
return errors.New("certificate request cannot contain DNS names")
case len(req.IPAddresses) > 0:
return errors.New("certificate request cannot contain IP addresses")
case len(req.URIs) > 0:
return errors.New("certificate request cannot contain URIs")
case len(req.EmailAddresses) == 0:
return errors.New("certificate request does not contain any email address")
case len(req.EmailAddresses) > 1:
return errors.New("certificate request does not contain too many email addresses")
case req.EmailAddresses[0] == "":
return errors.New("certificate request cannot contain an empty email address")
case req.EmailAddresses[0] != string(e):
return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e)
default:
return nil
}
}
// commonNameValidator validates the common name of a certificate request.
type commonNameValidator string
// Valid checks that certificate request common name matches the one configured.
func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
if req.Subject.CommonName == "" {
return errors.New("certificate request cannot contain an empty common name")
}
if req.Subject.CommonName != string(v) {
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
}
return nil
}
// dnsNamesValidator validates the DNS names SAN of a certificate request.
type dnsNamesValidator []string
// Valid checks that certificate request DNS Names match those configured in
// the bootstrap (token) flow.
func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
want := make(map[string]bool)
for _, s := range v {
want[s] = true
}
got := make(map[string]bool)
for _, s := range req.DNSNames {
got[s] = true
}
if !reflect.DeepEqual(want, got) {
return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
}
return nil
}
// ipAddressesValidator validates the IP addresses SAN of a certificate request.
type ipAddressesValidator []net.IP
// Valid checks that certificate request IP Addresses match those configured in
// the bootstrap (token) flow.
func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
want := make(map[string]bool)
for _, ip := range v {
want[ip.String()] = true
}
got := make(map[string]bool)
for _, ip := range req.IPAddresses {
got[ip.String()] = true
}
if !reflect.DeepEqual(want, got) {
return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v)
}
return nil
}
// validityValidator validates the certificate temporal validity settings.
type validityValidator struct {
min time.Duration
max time.Duration
}
// newValidityValidator return a new validity validator.
func newValidityValidator(min, max time.Duration) *validityValidator {
return &validityValidator{min: min, max: max}
}
// Validate validates the certificate temporal validity settings.
func (v *validityValidator) Valid(crt *x509.Certificate) error {
var (
na = crt.NotAfter
nb = crt.NotBefore
d = na.Sub(nb)
now = time.Now()
)
if na.Before(now) {
return errors.Errorf("NotAfter: %v cannot be in the past", na)
}
if na.Before(nb) {
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
}
if d < v.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
d, v.min)
}
if d > v.max {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
d, v.max)
}
return nil
}
var (
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
)
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
}
type provisionerExtensionOption struct {
Type int
Name string
CredentialID string
}
func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption {
return &provisionerExtensionOption{
Type: int(typ),
Name: name,
CredentialID: credentialID,
}
}
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
return func(p x509util.Profile) error {
crt := p.Subject()
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
if err != nil {
return err
}
crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
return nil
}
}
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
b, err := asn1.Marshal(stepProvisionerASN1{
Type: typ,
Name: []byte(name),
CredentialID: []byte(credentialID),
})
if err != nil {
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
}
return pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
}, nil
}
func init() {
// Avoid deadcode warning in profileWithOption
_ = profileWithOption(nil)
}

@ -0,0 +1,152 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"net"
"net/url"
"testing"
"time"
)
func Test_emailOnlyIdentity_Valid(t *testing.T) {
uri, err := url.Parse("https://example.com/1.0/getUser")
if err != nil {
t.Fatal(err)
}
type args struct {
req *x509.CertificateRequest
}
tests := []struct {
name string
e emailOnlyIdentity
args args
wantErr bool
}{
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_commonNameValidator_Valid(t *testing.T) {
type args struct {
req *x509.CertificateRequest
}
tests := []struct {
name string
v commonNameValidator
args args
wantErr bool
}{
{"ok", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
{"empty", "", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true},
{"wrong", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("commonNameValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_dnsNamesValidator_Valid(t *testing.T) {
type args struct {
req *x509.CertificateRequest
}
tests := []struct {
name string
v dnsNamesValidator
args args
wantErr bool
}{
{"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false},
{"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false},
{"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false},
{"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false},
{"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true},
{"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true},
{"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("dnsNamesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_ipAddressesValidator_Valid(t *testing.T) {
ip1 := net.IPv4(10, 3, 2, 1)
ip2 := net.IPv4(10, 3, 2, 2)
ip3 := net.IPv4(10, 3, 2, 3)
type args struct {
req *x509.CertificateRequest
}
tests := []struct {
name string
v ipAddressesValidator
args args
wantErr bool
}{
{"ok0", []net.IP{}, args{&x509.CertificateRequest{IPAddresses: []net.IP{}}}, false},
{"ok1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1}}}, false},
{"ok2", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip2}}}, false},
{"ok3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, false},
{"fail1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2}}}, true},
{"fail2", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, true},
{"fail3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip3}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("ipAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_validityValidator_Valid(t *testing.T) {
type fields struct {
min time.Duration
max time.Duration
}
type args struct {
crt *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validityValidator{
min: tt.fields.min,
max: tt.fields.max,
}
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr {
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf
MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla
Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg
Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN
Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw
QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU
B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c
ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET
/A8LXNH4M06A7vE=
-----END CERTIFICATE-----

@ -0,0 +1,272 @@
package provisioner
import (
"crypto"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"time"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
)
var testAudiences = []string{
"https://ca.smallstep.com/sign",
"https://ca.smallsteomcom/1.0/sign",
}
func must(args ...interface{}) []interface{} {
if l := len(args); l > 0 && args[l-1] != nil {
if err, ok := args[l-1].(error); ok {
panic(err)
}
}
return args
}
func generateJSONWebKey() (*jose.JSONWebKey, error) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
if err != nil {
return nil, err
}
fp, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return nil, err
}
jwk.KeyID = string(hex.EncodeToString(fp))
return jwk, nil
}
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
var keySet jose.JSONWebKeySet
for i := 0; i < n; i++ {
key, err := generateJSONWebKey()
if err != nil {
return jose.JSONWebKeySet{}, err
}
keySet.Keys = append(keySet.Keys, *key)
}
return keySet, nil
}
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
b, err := json.Marshal(jwk)
if err != nil {
return nil, err
}
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
if err != nil {
return nil, err
}
opts := new(jose.EncrypterOptions)
opts.WithContentType(jose.ContentType("jwk+json"))
recipient := jose.Recipient{
Algorithm: jose.PBES2_HS256_A128KW,
Key: []byte("password"),
PBES2Count: jose.PBKDF2Iterations,
PBES2Salt: salt,
}
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
if err != nil {
return nil, err
}
return encrypter.Encrypt(b)
}
func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) {
enc, err := jose.ParseEncrypted(key)
if err != nil {
return nil, err
}
b, err := enc.Decrypt([]byte("password"))
if err != nil {
return nil, err
}
jwk := new(jose.JSONWebKey)
if err := json.Unmarshal(b, jwk); err != nil {
return nil, err
}
return jwk, nil
}
func generateJWK() (*JWK, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey()
if err != nil {
return nil, err
}
jwe, err := encryptJSONWebKey(jwk)
if err != nil {
return nil, err
}
public := jwk.Public()
encrypted, err := jwe.CompactSerialize()
if err != nil {
return nil, err
}
return &JWK{
Name: name,
Type: "JWK",
Key: &public,
EncryptedKey: encrypted,
Claims: &globalProvisionerClaims,
audiences: testAudiences,
}, nil
}
func generateOIDC() (*OIDC, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
clientID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
issuer, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey()
if err != nil {
return nil, err
}
return &OIDC{
Name: name,
Type: "OIDC",
ClientID: clientID,
ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration",
Claims: &globalProvisionerClaims,
configuration: openIDConfiguration{
Issuer: issuer,
JWKSetURI: "https://example.com/.well-known/jwks",
},
keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
}, nil
}
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
col := NewCollection(testAudiences)
for i := 0; i < nJWK; i++ {
p, err := generateJWK()
if err != nil {
return nil, err
}
col.Store(p)
}
for i := 0; i < nOIDC; i++ {
p, err := generateOIDC()
if err != nil {
return nil, err
}
col.Store(p)
}
return col, nil
}
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk)
}
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
if err != nil {
return "", err
}
id, err := randutil.ASCII(64)
if err != nil {
return "", err
}
claims := struct {
jose.Claims
Email string `json:"email"`
SANS []string `json:"sans"`
}{
Claims: jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
},
Email: email,
SANS: sans,
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
tok, err := jose.ParseSigned(token)
if err != nil {
return nil, nil, err
}
claims := new(jose.Claims)
if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil {
return nil, nil, err
}
return tok, claims, nil
}
func generateJWKServer(n int) *httptest.Server {
hits := struct {
Hits int `json:"hits"`
}{}
writeJSON := func(w http.ResponseWriter, v interface{}) {
b, err := json.Marshal(v)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
var ret jose.JSONWebKeySet
for _, k := range ks.Keys {
ret.Keys = append(ret.Keys, k.Public())
}
return ret
}
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits.Hits++
switch r.RequestURI {
case "/error":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/hits":
writeJSON(w, hits)
case "/openid-configuration", "/.well-known/openid-configuration":
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
case "/random":
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(keySet))
case "/private":
writeJSON(w, defaultKeySet)
default:
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(defaultKeySet))
}
})
srv.Start()
return srv
}

@ -1,55 +0,0 @@
package authority
import (
"errors"
"testing"
"github.com/smallstep/assert"
jose "gopkg.in/square/go-jose.v2"
)
func TestProvisionerInit(t *testing.T) {
type ProvisionerValidateTest struct {
p *Provisioner
err error
}
tests := map[string]func(*testing.T) ProvisionerValidateTest{
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &Provisioner{},
err: errors.New("provisioner name cannot be empty"),
}
},
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &Provisioner{Name: "foo"},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &Provisioner{Name: "foo", Type: "bar"},
err: errors.New("provisioner key cannot be empty"),
}
},
"ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &Provisioner{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
}
},
}
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
err := tc.p.Init(&globalProvisionerClaims)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

@ -1,115 +1,25 @@
package authority
import (
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math"
"net/http"
"sort"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
val, ok := a.encryptedKeyIndex.Load(kid)
key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
http.StatusNotFound, context{}}
}
key, ok := val.(string)
if !ok {
return "", &apiError{errors.Errorf("stored value is not a string"),
http.StatusInternalServerError, context{}}
}
return key, nil
}
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
// with their public keys.
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) {
provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit)
func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) {
provisioners, nextCursor := a.provisioners.Find(cursor, limit)
return provisioners, nextCursor, nil
}
type uidProvisioner struct {
provisioner *Provisioner
uid string
}
func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) {
if len(provisioners) > math.MaxInt32 {
return nil, errors.New("too many provisioners")
}
var slice provisionerSlice
bi := make([]byte, 4)
for i, p := range provisioners {
sum, err := provisionerSum(p)
if err != nil {
return nil, err
}
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
binary.BigEndian.PutUint32(bi, uint32(i))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
slice = append(slice, uidProvisioner{
provisioner: p,
uid: hex.EncodeToString(sum),
})
}
sort.Sort(slice)
return slice, nil
}
type provisionerSlice []uidProvisioner
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) {
switch {
case limit <= 0:
limit = DefaultProvisionersLimit
case limit > DefaultProvisionersMax:
limit = DefaultProvisionersMax
}
n := len(p)
cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
var slice []*Provisioner
for ; i < n && len(slice) < limit; i++ {
slice = append(slice, p[i].provisioner)
}
if i < n {
return slice, strings.TrimLeft(p[i].uid, "0")
}
return slice, ""
}
// provisionerSum returns the SHA1 of the json representation of the
// provisioner. From this we will create the unique and sorted id.
func provisionerSum(p *Provisioner) ([]byte, error) {
b, err := json.Marshal(p.Key)
if err != nil {
return nil, errors.Wrap(err, "error marshalling provisioner")
}
sum := sha1.Sum(b)
return sum[:], nil
}

@ -1,16 +1,12 @@
package authority
import (
"encoding/json"
"net/http"
"reflect"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
"github.com/smallstep/certificates/authority/provisioner"
)
func TestGetEncryptedKey(t *testing.T) {
@ -27,7 +23,7 @@ func TestGetEncryptedKey(t *testing.T) {
assert.FatalError(t, err)
return &ek{
a: a,
kid: c.AuthorityConfig.Provisioners[1].Key.KeyID,
kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID,
}
},
"fail-not-found": func(t *testing.T) *ek {
@ -42,19 +38,6 @@ func TestGetEncryptedKey(t *testing.T) {
http.StatusNotFound, context{}},
}
},
"fail-invalid-type-found": func(t *testing.T) *ek {
c, err := LoadConfiguration("../ca/testdata/ca.json")
assert.FatalError(t, err)
a, err := New(c)
assert.FatalError(t, err)
a.encryptedKeyIndex.Store("foo", 5)
return &ek{
a: a,
kid: "foo",
err: &apiError{errors.Errorf("stored value is not a string"),
http.StatusInternalServerError, context{}},
}
},
}
for name, genTestCase := range tests {
@ -75,9 +58,9 @@ func TestGetEncryptedKey(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
val, ok := tc.a.provisionerIDIndex.Load("max:" + tc.kid)
val, ok := tc.a.provisioners.Load("max:" + tc.kid)
assert.Fatal(t, ok)
p, ok := val.(*Provisioner)
p, ok := val.(*provisioner.JWK)
assert.Fatal(t, ok)
assert.Equals(t, p.EncryptedKey, ek)
}
@ -126,102 +109,3 @@ func TestGetProvisioners(t *testing.T) {
})
}
}
func generateProvisioner(t *testing.T) *Provisioner {
name, err := randutil.Alphanumeric(10)
assert.FatalError(t, err)
// Create a new JWK
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
// Encrypt JWK
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
assert.FatalError(t, err)
b, err := json.Marshal(jwk)
assert.FatalError(t, err)
recipient := jose.Recipient{
Algorithm: jose.PBES2_HS256_A128KW,
Key: []byte("password"),
PBES2Count: jose.PBKDF2Iterations,
PBES2Salt: salt,
}
opts := new(jose.EncrypterOptions)
opts.WithContentType(jose.ContentType("jwk+json"))
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
assert.FatalError(t, err)
jwe, err := encrypter.Encrypt(b)
assert.FatalError(t, err)
// get public and encrypted keys
public := jwk.Public()
encrypted, err := jwe.CompactSerialize()
assert.FatalError(t, err)
return &Provisioner{
Name: name,
Type: "JWT",
Key: &public,
EncryptedKey: encrypted,
}
}
func Test_newSortedProvisioners(t *testing.T) {
provisioners := make([]*Provisioner, 20)
for i := range provisioners {
provisioners[i] = generateProvisioner(t)
}
ps, err := newSortedProvisioners(provisioners)
assert.FatalError(t, err)
prev := ""
for i, p := range ps {
if p.uid < prev {
t.Errorf("%s should be less that %s", p.uid, prev)
}
if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
}
prev = p.uid
}
}
func Test_provisionerSlice_Find(t *testing.T) {
trim := func(s string) string {
return strings.TrimLeft(s, "0")
}
provisioners := make([]*Provisioner, 20)
for i := range provisioners {
provisioners[i] = generateProvisioner(t)
}
ps, err := newSortedProvisioners(provisioners)
assert.FatalError(t, err)
type args struct {
cursor string
limit int
}
tests := []struct {
name string
p provisionerSlice
args args
want []*Provisioner
want1 string
}{
{"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
{"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
{"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
{"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
{"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
{"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
{"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
{"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

@ -3,7 +3,6 @@ package authority
import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"net/http"
@ -11,6 +10,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
@ -22,48 +22,7 @@ func (a *Authority) GetTLSOptions() *tlsutil.TLSOptions {
return a.config.TLS
}
// SignOptions contains the options that can be passed to the Authority.Sign
// method.
type SignOptions struct {
NotAfter time.Time `json:"notAfter"`
NotBefore time.Time `json:"notBefore"`
}
var (
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
)
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
}
const provisionerTypeJWK = 1
func withProvisionerOID(name, kid string) x509util.WithOption {
return func(p x509util.Profile) error {
crt := p.Subject()
b, err := asn1.Marshal(stepProvisionerASN1{
Type: provisionerTypeJWK,
Name: []byte(name),
CredentialID: []byte(kid),
})
if err != nil {
return err
}
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
})
return nil
}
}
var oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
return func(p x509util.Profile) error {
@ -96,28 +55,22 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
}
// Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) {
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
var (
errContext = context{"csr": csr, "signOptions": signOpts}
claims = []certClaim{}
mods = []x509util.WithOption{}
errContext = context{"csr": csr, "signOptions": signOpts}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{}
)
for _, op := range extraOpts {
switch k := op.(type) {
case certClaim:
claims = append(claims, k)
case x509util.WithOption:
mods = append(mods, k)
case *Provisioner:
m, c, err := k.getTLSApps(signOpts)
if err != nil {
return nil, nil, &apiError{err, http.StatusInternalServerError, errContext}
case provisioner.CertificateValidator:
certValidators = append(certValidators, k)
case provisioner.CertificateRequestValidator:
if err := k.Valid(csr); err != nil {
return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
}
mods = append(mods, m...)
mods = append(mods, []x509util.WithOption{
withDefaultASN1DN(a.config.AuthorityConfig.Template),
}...)
claims = append(claims, c...)
case provisioner.ProfileModifier:
mods = append(mods, k.Option(signOpts))
default:
return nil, nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k),
http.StatusInternalServerError, errContext}
@ -137,10 +90,6 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext}
}
if err := validateClaims(leaf.Subject(), claims); err != nil {
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusUnauthorized, errContext}
}
crtBytes, err := leaf.CreateCertificate()
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"),
@ -153,6 +102,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext
http.StatusInternalServerError, errContext}
}
// FIXME: This should be before creating the certificate.
for _, v := range certValidators {
if err := v.Valid(serverCert); err != nil {
return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
}
}
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"),

@ -7,7 +7,6 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"net"
"net/http"
"reflect"
"testing"
@ -15,12 +14,49 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose"
stepx509 "github.com/smallstep/cli/pkg/x509"
)
var (
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
)
const provisionerTypeJWK = 1
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
}
func withProvisionerOID(name, kid string) x509util.WithOption {
return func(p x509util.Profile) error {
crt := p.Subject()
b, err := asn1.Marshal(stepProvisionerASN1{
Type: provisionerTypeJWK,
Name: []byte(name),
CredentialID: []byte(kid),
})
if err != nil {
return err
}
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
})
return nil
}
}
func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateRequest)) *x509.CertificateRequest {
_csr := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: "smallstep test"},
@ -52,24 +88,25 @@ func TestSign(t *testing.T) {
}
nb := time.Now()
signOpts := SignOptions{
signOpts := provisioner.Options{
NotBefore: nb,
NotAfter: nb.Add(time.Minute * 5),
}
p := a.config.AuthorityConfig.Provisioners[1]
extraOpts := []interface{}{
&commonNameClaim{"smallstep test"},
&dnsNamesClaim{[]string{"test.smallstep.com"}},
&ipAddressesClaim{[]net.IP{}},
p,
}
// Create a token to get test extra opts.
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err)
extraOpts, err := a.Authorize(token)
assert.FatalError(t, err)
type signTest struct {
auth *Authority
csr *x509.CertificateRequest
signOpts SignOptions
extraOpts []interface{}
signOpts provisioner.Options
extraOpts []provisioner.SignOption
err *apiError
}
tests := map[string]func(*testing.T) *signTest{
@ -123,7 +160,7 @@ func TestSign(t *testing.T) {
return &signTest{
auth: _a,
csr: csr,
extraOpts: []interface{}{p},
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"),
http.StatusInternalServerError,
@ -133,7 +170,7 @@ func TestSign(t *testing.T) {
},
"fail provisioner duration claim": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
_signOpts := SignOptions{
_signOpts := provisioner.Options{
NotBefore: nb,
NotAfter: nb.Add(time.Hour * 25),
}
@ -157,7 +194,7 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: DNS names claim failed - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
http.StatusUnauthorized,
context{"csr": csr, "signOptions": signOpts},
},
@ -262,7 +299,7 @@ func TestRenew(t *testing.T) {
now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7)
na1 := now
so := &SignOptions{
so := &provisioner.Options{
NotBefore: nb1,
NotAfter: na1,
}
@ -272,7 +309,7 @@ func TestRenew(t *testing.T) {
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
withDefaultASN1DN(a.config.AuthorityConfig.Template),
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].Key.KeyID))
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID))
assert.FatalError(t, err)
crtBytes, err := leaf.CreateCertificate()
assert.FatalError(t, err)
@ -284,7 +321,7 @@ func TestRenew(t *testing.T) {
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
withDefaultASN1DN(a.config.AuthorityConfig.Template),
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID),
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID),
)
assert.FatalError(t, err)
crtBytesNoRenew, err := leafNoRenew.CreateCertificate()
@ -321,7 +358,7 @@ func TestRenew(t *testing.T) {
}
return &renewTest{
crt: crtNoRenew,
err: &apiError{errors.New("renew disabled"),
err: &apiError{errors.New("renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, ctx},
}, nil
},

@ -2,48 +2,10 @@ package authority
import (
"encoding/json"
"time"
"github.com/pkg/errors"
)
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
type Duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.Duration.String())
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
var (
s string
_d time.Duration
)
if d == nil {
return errors.New("duration cannot be nil")
}
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if _d, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
d.Duration = _d
return
}
// multiString represents a type that can be encoded/decoded in JSON as a single
// string or an array of strings.
type multiString []string

@ -3,7 +3,6 @@ package authority
import (
"reflect"
"testing"
"time"
)
func Test_multiString_First(t *testing.T) {
@ -101,57 +100,3 @@ func Test_multiString_UnmarshalJSON(t *testing.T) {
})
}
}
func TestDuration_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
d *Duration
args args
want *Duration
wantErr bool
}{
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
{"nil", nil, args{nil}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(tt.d, tt.want) {
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
}
})
}
}
func Test_duration_MarshalJSON(t *testing.T) {
tests := []struct {
name string
d *Duration
want []byte
wantErr bool
}{
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
}
})
}
}

@ -7,7 +7,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
provisioners "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca"
"github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/randutil"
@ -111,10 +111,12 @@ func loadProvisionerJWKByName(name, caURL, caRoot, passFile string) (key *jose.J
}
for _, provisioner := range provisioners {
if provisioner.Name == name {
key, err = decryptProvisionerJWK(provisioner.EncryptedKey, passFile)
if err == nil {
return
if provisioner.GetName() == name {
if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok {
key, err = decryptProvisionerJWK(encryptedKey, passFile)
if err == nil {
return
}
}
}
}
@ -154,7 +156,7 @@ func getRootCAPath() string {
}
// getProvisioners returns the map of provisioners on the given CA.
func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) {
func getProvisioners(caURL, rootFile string) (provisioners.List, error) {
if len(rootFile) == 0 {
rootFile = getRootCAPath()
}
@ -163,7 +165,7 @@ func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) {
return nil, err
}
cursor := ""
provisioners := []*authority.Provisioner{}
var provisioners provisioners.List
for {
resp, err := client.Provisioners(ca.WithProvisionerCursor(cursor), ca.WithProvisionerLimit(100))
if err != nil {

@ -20,6 +20,7 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil"
@ -389,7 +390,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
}
},
"ok": func(t *testing.T) *ekt {
p := config.AuthorityConfig.Provisioners[2]
p := config.AuthorityConfig.Provisioners[2].(*provisioner.JWK)
return &ekt{
ca: ca,
kid: p.Key.KeyID,

@ -446,7 +446,11 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error)
return nil, nil, errors.Wrap(err, "error generating key")
}
var emails []string
dnsNames, ips := x509util.SplitSANs(claims.SANs)
if claims.Email != "" {
emails = append(emails, claims.Email)
}
template := &x509.CertificateRequest{
Subject: pkix.Name{
@ -455,6 +459,7 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error)
SignatureAlgorithm: x509.ECDSAWithSHA256,
DNSNames: dnsNames,
IPAddresses: ips,
EmailAddresses: emails,
}
csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk)

@ -14,7 +14,7 @@ import (
"time"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
)
const (
@ -391,7 +391,7 @@ func TestClient_Renew(t *testing.T) {
func TestClient_Provisioners(t *testing.T) {
ok := &api.ProvisionersResponse{
Provisioners: []*authority.Provisioner{},
Provisioners: provisioner.List{},
}
internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error"))

@ -2,6 +2,7 @@ package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"log"
@ -143,40 +144,12 @@ intermediate private key.`,
}
app.Action = func(ctx *cli.Context) error {
passFile := ctx.String("password-file")
// If zero cmd line args show help, if >1 cmd line args show error.
if ctx.NArg() == 0 {
return cli.ShowAppHelp(ctx)
}
if err := errs.NumberOfArguments(ctx, 1); err != nil {
return err
}
configFile := ctx.Args().Get(0)
config, err := authority.LoadConfiguration(configFile)
if err != nil {
fatal(err)
}
var password []byte
if passFile != "" {
if password, err = ioutil.ReadFile(passFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", passFile))
}
password = bytes.TrimRightFunc(password, unicode.IsSpace)
}
srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password))
if err != nil {
fatal(err)
}
go ca.StopReloaderHandler(srv)
if err = srv.Run(); err != nil && err != http.ErrServerClosed {
fatal(err)
}
return nil
// Hack to be able to run a the top action as a subcommand
cmd := cli.Command{Name: "start", Action: startAction, Flags: app.Flags}
set := flag.NewFlagSet(app.Name, flag.ContinueOnError)
set.Parse(os.Args)
ctx = cli.NewContext(app, set, nil)
return cmd.Run(ctx)
}
if err := app.Run(os.Args); err != nil {
@ -189,6 +162,43 @@ intermediate private key.`,
}
}
func startAction(ctx *cli.Context) error {
passFile := ctx.String("password-file")
// If zero cmd line args show help, if >1 cmd line args show error.
if ctx.NArg() == 0 {
return cli.ShowAppHelp(ctx)
}
if err := errs.NumberOfArguments(ctx, 1); err != nil {
return err
}
configFile := ctx.Args().Get(0)
config, err := authority.LoadConfiguration(configFile)
if err != nil {
fatal(err)
}
var password []byte
if passFile != "" {
if password, err = ioutil.ReadFile(passFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", passFile))
}
password = bytes.TrimRightFunc(password, unicode.IsSpace)
}
srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password))
if err != nil {
fatal(err)
}
go ca.StopReloaderHandler(srv)
if err = srv.Run(); err != nil && err != http.ErrServerClosed {
fatal(err)
}
return nil
}
// fatal writes the passed error on the standard error and exits with the exit
// code 1. If the environment variable STEPDEBUG is set to 1 it shows the
// stack trace of the error.

Loading…
Cancel
Save