Merge pull request #1660 from smallstep/fix-1637

Check cnf claim with CSR fingerprint
pull/1945/head
Mariano Cano 3 months ago committed by GitHub
commit 9c95d3412c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -89,6 +89,39 @@ func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID)
for k, v := range extraHeaders {
so.WithHeader(jose.HeaderKey(k), v)
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
if err != nil {
return "", err
}
id, err := randutil.ASCII(64)
if err != nil {
return "", err
}
iat := time.Now()
claims := jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
}
return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize()
}
func TestAuthority_authorizeToken(t *testing.T) { func TestAuthority_authorizeToken(t *testing.T) {
a := testAuthority(t) a := testAuthority(t)
@ -510,7 +543,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, 10, len(got)) // number of provisioner.SignOptions returned assert.Equals(t, 11, len(got)) // number of provisioner.SignOptions returned
} }
} }
}) })

@ -493,8 +493,8 @@ func (p *GCP) genHostOptions(_ context.Context, claims *gcpPayload) (SignSSHOpti
return SignSSHOptions{CertType: SSHHostCert}, keyID, principals, sshutil.HostCert, sshutil.DefaultIIDTemplate return SignSSHOptions{CertType: SSHHostCert}, keyID, principals, sshutil.HostCert, sshutil.DefaultIIDTemplate
} }
func FormatServiceAccountUsername(serviceAccountId string) string { func FormatServiceAccountUsername(serviceAccountID string) string {
return fmt.Sprintf("sa_%v", serviceAccountId) return fmt.Sprintf("sa_%v", serviceAccountID)
} }
func (p *GCP) genUserOptions(_ context.Context, claims *gcpPayload) (SignSSHOptions, string, []string, sshutil.CertType, string) { func (p *GCP) genUserOptions(_ context.Context, claims *gcpPayload) (SignSSHOptions, string, []string, sshutil.CertType, string) {

@ -19,8 +19,9 @@ import (
// jwtPayload extends jwt.Claims with step attributes. // jwtPayload extends jwt.Claims with step attributes.
type jwtPayload struct { type jwtPayload struct {
jose.Claims jose.Claims
SANs []string `json:"sans,omitempty"` SANs []string `json:"sans,omitempty"`
Step *stepPayload `json:"step,omitempty"` Step *stepPayload `json:"step,omitempty"`
Confirmation *cnfPayload `json:"cnf,omitempty"`
} }
type stepPayload struct { type stepPayload struct {
@ -28,6 +29,10 @@ type stepPayload struct {
RA *RAInfo `json:"ra,omitempty"` RA *RAInfo `json:"ra,omitempty"`
} }
type cnfPayload struct {
Fingerprint string `json:"x5rt#S256,omitempty"`
}
// JWK is the default provisioner, an entity that can sign tokens necessary for // JWK is the default provisioner, an entity that can sign tokens necessary for
// signature requests. // signature requests.
type JWK struct { type JWK struct {
@ -183,6 +188,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
} }
// Check the fingerprint of the certificate request if given.
var fingerprint string
if claims.Confirmation != nil {
fingerprint = claims.Confirmation.Fingerprint
}
return []SignOption{ return []SignOption{
self, self,
templateOptions, templateOptions,
@ -190,6 +201,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID).WithControllerOptions(p.ctl), newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID).WithControllerOptions(p.ctl),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
csrFingerprintValidator(fingerprint),
commonNameSliceValidator(append([]string{claims.Subject}, claims.SANs...)), commonNameSliceValidator(append([]string{claims.Subject}, claims.SANs...)),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newDefaultSANsValidator(ctx, claims.SANs), newDefaultSANsValidator(ctx, claims.SANs),

@ -13,7 +13,9 @@ import (
"testing" "testing"
"time" "time"
"go.step.sm/crypto/fingerprint"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
@ -247,6 +249,9 @@ func TestJWK_AuthorizeSign(t *testing.T) {
t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{}, time.Now(), key1) t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
t3, err := generateCustomToken("subject", p1.Name, testAudiences.Sign[0], key1, nil, map[string]any{"cnf": map[string]any{"x5rt#S256": "fingerprint"}})
assert.FatalError(t, err)
// invalid signature // invalid signature
failSig := t1[0 : len(t1)-2] failSig := t1[0 : len(t1)-2]
@ -254,12 +259,13 @@ func TestJWK_AuthorizeSign(t *testing.T) {
token string token string
} }
tests := []struct { tests := []struct {
name string name string
prov *JWK prov *JWK
args args args args
code int code int
err error err error
sans []string sans []string
fingerprint string
}{ }{
{ {
name: "fail-signature", name: "fail-signature",
@ -284,6 +290,15 @@ func TestJWK_AuthorizeSign(t *testing.T) {
err: nil, err: nil,
sans: []string{"subject"}, sans: []string{"subject"},
}, },
{
name: "ok-cnf",
prov: p1,
args: args{t3},
code: http.StatusOK,
err: nil,
sans: []string{"subject"},
fingerprint: "fingerprint",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -297,7 +312,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
} }
} else { } else {
if assert.NotNil(t, got) { if assert.NotNil(t, got) {
assert.Equals(t, 10, len(got)) assert.Equals(t, 11, len(got))
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *JWK: case *JWK:
@ -321,6 +336,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
case *WebhookController: case *WebhookController:
case csrFingerprintValidator:
assert.Equals(t, tt.fingerprint, string(v))
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
@ -393,17 +410,6 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
iss, aud := p1.Name, testAudiences.SSHSign[0]
t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
assert.FatalError(t, err)
t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
key, err := generateJSONWebKey() key, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -417,6 +423,39 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
// Calculate fingerprint
sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err)
fp, err := fingerprint.New(sshPub.Marshal(), crypto.SHA256, fingerprint.Base64RawURLFingerprint)
assert.FatalError(t, err)
iss, aud := p1.Name, testAudiences.SSHSign[0]
t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
assert.FatalError(t, err)
t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
assert.FatalError(t, err)
t3, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{
"step": map[string]any{
"ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}},
},
"cnf": map[string]any{"kid": fp},
})
assert.FatalError(t, err)
t4, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{
"step": map[string]any{
"ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}},
},
"cnf": map[string]any{"kid": "bad-fingerprint"},
})
assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
@ -451,9 +490,11 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
{"host-type", p1, args{t2, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-type", p1, args{t2, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-principals", p1, args{t2, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-principals", p1, args{t2, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-options", p1, args{t2, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-options", p1, args{t2, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-cnf", p1, args{t3, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ignore-bad-cnf", p1, args{t4, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-sshCA-disabled", p2, args{"foo", SignSSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false}, {"fail-sshCA-disabled", p2, args{"foo", SignSSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false},
{"fail-signature", p1, args{failSig, SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, {"fail-signature", p1, args{failSig, SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
{"rail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

@ -5,7 +5,10 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/subtle"
"crypto/x509" "crypto/x509"
"encoding/base64"
"encoding/json" "encoding/json"
"net" "net"
"net/http" "net/http"
@ -503,3 +506,21 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
cert.ExtraExtensions = append(cert.ExtraExtensions, ext) cert.ExtraExtensions = append(cert.ExtraExtensions, ext)
return nil return nil
} }
// csrFingerprintValidator is a CertificateRequestValidator that checks the
// fingerprint of the certificate request with the provided one.
type csrFingerprintValidator string
func (s csrFingerprintValidator) Valid(cr *x509.CertificateRequest) error {
if s != "" {
expected, err := base64.RawURLEncoding.DecodeString(string(s))
if err != nil {
return errs.ForbiddenErr(err, "error decoding fingerprint")
}
sum := sha256.Sum256(cr.Raw)
if subtle.ConstantTimeCompare(expected, sum[:]) != 1 {
return errs.Forbidden("certificate request fingerprint does not match %q", s)
}
}
return nil
}

@ -44,6 +44,13 @@ type SSHCertOptionsValidator interface {
Valid(got SignSSHOptions) error Valid(got SignSSHOptions) error
} }
// SSHPublicKeyValidator is the interface used to validate the public key of an
// SSH certificate.
type SSHPublicKeyValidator interface {
SignOption
Valid(got ssh.PublicKey) error
}
// SignSSHOptions contains the options that can be passed to the SignSSH method. // SignSSHOptions contains the options that can be passed to the SignSSH method.
type SignSSHOptions struct { type SignSSHOptions struct {
CertType string `json:"certType"` CertType string `json:"certType"`

@ -767,6 +767,37 @@ func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jw
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID)
for k, v := range extraHeaders {
so.WithHeader(jose.HeaderKey(k), v)
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
if err != nil {
return "", err
}
id, err := randutil.ASCII(64)
if err != nil {
return "", err
}
iat := time.Now()
claims := jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
}
return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize()
}
func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithType("JWT") so.WithType("JWT")

@ -21,9 +21,10 @@ import (
// x5cPayload extends jwt.Claims with step attributes. // x5cPayload extends jwt.Claims with step attributes.
type x5cPayload struct { type x5cPayload struct {
jose.Claims jose.Claims
SANs []string `json:"sans,omitempty"` SANs []string `json:"sans,omitempty"`
Step *stepPayload `json:"step,omitempty"` Step *stepPayload `json:"step,omitempty"`
chains [][]*x509.Certificate Confirmation *cnfPayload `json:"cnf,omitempty"`
chains [][]*x509.Certificate
} }
// X5C is the default provisioner, an entity that can sign tokens necessary for // X5C is the default provisioner, an entity that can sign tokens necessary for
@ -233,6 +234,12 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
} }
// Check the fingerprint of the certificate request if given.
var fingerprint string
if claims.Confirmation != nil {
fingerprint = claims.Confirmation.Fingerprint
}
return []SignOption{ return []SignOption{
self, self,
templateOptions, templateOptions,
@ -243,6 +250,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
x5cLeaf.NotBefore, x5cLeaf.NotAfter, x5cLeaf.NotBefore, x5cLeaf.NotAfter,
}, },
// validators // validators
csrFingerprintValidator(fingerprint),
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
newDefaultSANsValidator(ctx, claims.SANs), newDefaultSANsValidator(ctx, claims.SANs),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},

@ -3,9 +3,11 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"testing" "testing"
"time" "time"
@ -14,13 +16,19 @@ import (
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func assertHasPrefix(t *testing.T, s, p string) bool {
t.Helper()
return assert.True(t, strings.HasPrefix(s, p), "%q is not a prefix of %q", p, s)
}
func TestX5C_Getters(t *testing.T) { func TestX5C_Getters(t *testing.T) {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
id := "x5c/" + p.Name id := "x5c/" + p.Name
if got := p.GetID(); got != id { if got := p.GetID(); got != id {
t.Errorf("X5C.GetID() = %v, want %v:%v", got, p.Name, id) t.Errorf("X5C.GetID() = %v, want %v:%v", got, p.Name, id)
@ -79,7 +87,7 @@ func TestX5C_Init(t *testing.T) {
}, },
"fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
p.Claims = &Claims{DefaultTLSDur: &Duration{0}} p.Claims = &Claims{DefaultTLSDur: &Duration{0}}
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: p, p: p,
@ -88,7 +96,7 @@ func TestX5C_Init(t *testing.T) {
}, },
"ok": func(t *testing.T) ProvisionerValidateTest { "ok": func(t *testing.T) ProvisionerValidateTest {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: p, p: p,
} }
@ -117,7 +125,7 @@ VR0RBA0wC4IJcm9vdC10ZXN0MAoGCCqGSM49BAMCA0kAMEYCIQC2vgqwla0u8LHH
1MHob14qvS5o76HautbIBW7fcHzz5gIhAIx5A2+wkJYX4026kqaZCk/1sAwTxSGY 1MHob14qvS5o76HautbIBW7fcHzz5gIhAIx5A2+wkJYX4026kqaZCk/1sAwTxSGY
M46l92gdOozT M46l92gdOozT
-----END CERTIFICATE-----`)) -----END CERTIFICATE-----`))
assert.FatalError(t, err) require.NoError(t, err)
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: p, p: p,
extraValid: func(p *X5C) error { extraValid: func(p *X5C) error {
@ -143,11 +151,11 @@ M46l92gdOozT
err := tc.p.Init(config) err := tc.p.Init(config)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error()) assert.EqualError(t, tc.err, err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) assert.Equal(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
if tc.extraValid != nil { if tc.extraValid != nil {
assert.Nil(t, tc.extraValid(tc.p)) assert.Nil(t, tc.extraValid(tc.p))
} }
@ -159,9 +167,9 @@ M46l92gdOozT
func TestX5C_authorizeToken(t *testing.T) { func TestX5C_authorizeToken(t *testing.T) {
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err) require.NoError(t, err)
x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err) require.NoError(t, err)
type test struct { type test struct {
p *X5C p *X5C
@ -172,7 +180,7 @@ func TestX5C_authorizeToken(t *testing.T) {
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test { "fail/bad-token": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
@ -192,15 +200,15 @@ DgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNLJ
P9K7MAoGCCqGSM49BAMCA0gAMEUCIQC5c1ldDcesDb31GlO5cEJvOcRrIrNtkk8m P9K7MAoGCCqGSM49BAMCA0gAMEUCIQC5c1ldDcesDb31GlO5cEJvOcRrIrNtkk8m
a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo= a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo=
-----END CERTIFICATE-----`)) -----END CERTIFICATE-----`))
assert.FatalError(t, err) require.NoError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) require.NoError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -231,15 +239,15 @@ BgNVHREECTAHggVsZWFmMjAKBggqhkjOPQQDAgNIADBFAiB7gMRy3t81HpcnoRAS
ELZmDFaEnoLCsVfbmanFykazQQIhAI0sZjoE9t6gvzQp7XQp6CoxzCc3Jv3FwZ8G ELZmDFaEnoLCsVfbmanFykazQQIhAI0sZjoE9t6gvzQp7XQp6CoxzCc3Jv3FwZ8G
EXAHTA9L EXAHTA9L
-----END CERTIFICATE-----`)) -----END CERTIFICATE-----`))
assert.FatalError(t, err) require.NoError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) require.NoError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -272,16 +280,16 @@ E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1
2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC
lgsqsR63is+0YQ== lgsqsR63is+0YQ==
-----END CERTIFICATE-----`)) -----END CERTIFICATE-----`))
assert.FatalError(t, err) require.NoError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) require.NoError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -314,15 +322,15 @@ E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1
2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC
lgsqsR63is+0YQ== lgsqsR63is+0YQ==
-----END CERTIFICATE-----`)) -----END CERTIFICATE-----`))
assert.FatalError(t, err) require.NoError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) require.NoError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", tok, err := generateToken("", "foobar", testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -332,11 +340,11 @@ lgsqsR63is+0YQ==
}, },
"fail/invalid-issuer": func(t *testing.T) test { "fail/invalid-issuer": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", tok, err := generateToken("", "foobar", testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -346,11 +354,11 @@ lgsqsR63is+0YQ==
}, },
"fail/invalid-audience": func(t *testing.T) test { "fail/invalid-audience": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", p.GetName(), "foobar", "", tok, err := generateToken("", p.GetName(), "foobar", "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -360,11 +368,11 @@ lgsqsR63is+0YQ==
}, },
"fail/empty-subject": func(t *testing.T) test { "fail/empty-subject": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -374,11 +382,11 @@ lgsqsR63is+0YQ==
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -392,12 +400,12 @@ lgsqsR63is+0YQ==
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equal(t, tc.code, sc.StatusCode())
} }
assert.HasPrefix(t, err.Error(), tc.err.Error()) assertHasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.NoError(t, tc.err) {
assert.NotNil(t, claims) assert.NotNil(t, claims)
assert.NotNil(t, claims.chains) assert.NotNil(t, claims.chains)
} }
@ -408,21 +416,22 @@ lgsqsR63is+0YQ==
func TestX5C_AuthorizeSign(t *testing.T) { func TestX5C_AuthorizeSign(t *testing.T) {
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err) require.NoError(t, err)
jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err) require.NoError(t, err)
type test struct { type test struct {
p *X5C p *X5C
token string token string
code int code int
err error err error
sans []string sans []string
fingerprint string
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
@ -432,11 +441,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
}, },
"ok/empty-sans": func(t *testing.T) test { "ok/empty-sans": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{}, time.Now(), jwk, []string{}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -445,65 +454,90 @@ func TestX5C_AuthorizeSign(t *testing.T) {
}, },
"ok/multi-sans": func(t *testing.T) test { "ok/multi-sans": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"127.0.0.1", "foo", "max@smallstep.com"}, time.Now(), jwk, []string{"127.0.0.1", "foo", "max@smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, sans: []string{"127.0.0.1", "foo", "max@smallstep.com"},
} }
}, },
"ok/cnf": func(t *testing.T) test {
p, err := generateX5C(nil)
require.NoError(t, err)
x5c := make([]string, len(certs))
for i, cert := range certs {
x5c[i] = base64.StdEncoding.EncodeToString(cert.Raw)
}
extraHeaders := map[string]any{"x5c": x5c}
extraClaims := map[string]any{
"sans": []string{"127.0.0.1", "foo", "max@smallstep.com"},
"cnf": map[string]any{"x5rt#S256": "fingerprint"},
}
tok, err := generateCustomToken("foo", p.GetName(), testAudiences.Sign[0], jwk, extraHeaders, extraClaims)
require.NoError(t, err)
return test{
p: p,
token: tok,
sans: []string{"127.0.0.1", "foo", "max@smallstep.com"},
fingerprint: "fingerprint",
}
},
} }
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
ctx := NewContextWithMethod(context.Background(), SignIdentityMethod) ctx := NewContextWithMethod(context.Background(), SignIdentityMethod)
if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err, err.Error()) {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equal(t, tc.code, sc.StatusCode())
} }
assert.HasPrefix(t, err.Error(), tc.err.Error()) assertHasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) { if assert.NotNil(t, opts) {
assert.Equals(t, 10, len(opts)) assert.Len(t, opts, 11)
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *X5C: case *X5C:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeX5C) assert.Equal(t, TypeX5C, v.Type)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equal(t, tc.p.GetName(), v.Name)
assert.Equals(t, v.CredentialID, "") assert.Equal(t, "", v.CredentialID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, v.KeyValuePairs, 0)
case profileLimitDuration: case profileLimitDuration:
assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) assert.Equal(t, tc.p.ctl.Claimer.DefaultTLSCertDuration(), v.def)
claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) assert.Equal(t, claims.chains[0][0].NotAfter, v.notAfter)
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "foo") assert.Equal(t, "foo", string(v))
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *defaultSANsValidator: case *defaultSANsValidator:
assert.Equals(t, v.sans, tc.sans) assert.Equal(t, tc.sans, v.sans)
assert.Equals(t, MethodFromContext(v.ctx), SignIdentityMethod) assert.Equal(t, SignIdentityMethod, MethodFromContext(v.ctx))
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equal(t, tc.p.ctl.Claimer.MinTLSCertDuration(), v.min)
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) assert.Equal(t, tc.p.ctl.Claimer.MaxTLSCertDuration(), v.max)
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equal(t, nil, v.policyEngine)
case *WebhookController: case *WebhookController:
assert.Len(t, 0, v.webhooks) assert.Len(t, v.webhooks, 0)
assert.Equals(t, linkedca.Webhook_X509, v.certType) assert.Equal(t, linkedca.Webhook_X509, v.certType)
assert.Len(t, 2, v.options) assert.Len(t, v.options, 2)
case csrFingerprintValidator:
assert.Equal(t, tc.fingerprint, string(v))
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -523,7 +557,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
@ -533,16 +567,16 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err) require.NoError(t, err)
jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err) require.NoError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Revoke[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Revoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs)) withX5CHdr(certs))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -556,9 +590,9 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equal(t, tc.code, sc.StatusCode())
} }
assert.HasPrefix(t, err.Error(), tc.err.Error()) assertHasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
assert.Nil(t, tc.err) assert.Nil(t, tc.err)
@ -577,12 +611,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test { "fail/renew-disabled": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
@ -591,7 +625,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
} }
@ -607,9 +641,9 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equal(t, tc.code, sc.StatusCode())
} }
assert.HasPrefix(t, err.Error(), tc.err.Error()) assertHasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
assert.Nil(t, tc.err) assert.Nil(t, tc.err)
@ -620,28 +654,30 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
func TestX5C_AuthorizeSSHSign(t *testing.T) { func TestX5C_AuthorizeSSHSign(t *testing.T) {
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err) require.NoError(t, err)
x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err) require.NoError(t, err)
_, fn := mockNow() _, fn := mockNow()
defer fn() defer fn()
type test struct { type test struct {
p *X5C p *X5C
token string token string
claims *x5cPayload claims *x5cPayload
code int fingerprint string
err error count int
code int
err error
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/sshCA-disabled": func(t *testing.T) test { "fail/sshCA-disabled": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
// disable sshCA // disable sshCA
enable := false enable := false
p.Claims = &Claims{EnableSSHCA: &enable} p.Claims = &Claims{EnableSSHCA: &enable}
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
@ -651,7 +687,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
}, },
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
@ -661,11 +697,11 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
}, },
"fail/no-Step-claim": func(t *testing.T) test { "fail/no-Step-claim": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -675,10 +711,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
}, },
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test { "fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
id, err := randutil.ASCII(64) id, err := randutil.ASCII(64)
assert.FatalError(t, err) require.NoError(t, err)
now := time.Now() now := time.Now()
claims := &x5cPayload{ claims := &x5cPayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -693,7 +729,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
Step: &stepPayload{}, Step: &stepPayload{},
} }
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
@ -703,10 +739,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
}, },
"ok/with-claims": func(t *testing.T) test { "ok/with-claims": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
id, err := randutil.ASCII(64) id, err := randutil.ASCII(64)
assert.FatalError(t, err) require.NoError(t, err)
now := time.Now() now := time.Now()
claims := &x5cPayload{ claims := &x5cPayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -719,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
Audience: []string{testAudiences.SSHSign[0]}, Audience: []string{testAudiences.SSHSign[0]},
}, },
Step: &stepPayload{SSH: &SignSSHOptions{ Step: &stepPayload{SSH: &SignSSHOptions{
CertType: SSHHostCert, CertType: SSHUserCert,
KeyID: "foo", KeyID: "foo",
Principals: []string{"max", "mariano", "alan"}, Principals: []string{"max", "mariano", "alan"},
ValidAfter: TimeDuration{d: 5 * time.Minute}, ValidAfter: TimeDuration{d: 5 * time.Minute},
@ -727,19 +763,20 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
}}, }},
} }
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
claims: claims, claims: claims,
token: tok, token: tok,
count: 12,
} }
}, },
"ok/without-claims": func(t *testing.T) test { "ok/without-claims": func(t *testing.T) test {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) require.NoError(t, err)
id, err := randutil.ASCII(64) id, err := randutil.ASCII(64)
assert.FatalError(t, err) require.NoError(t, err)
now := time.Now() now := time.Now()
claims := &x5cPayload{ claims := &x5cPayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -754,11 +791,47 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
Step: &stepPayload{SSH: &SignSSHOptions{}}, Step: &stepPayload{SSH: &SignSSHOptions{}},
} }
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
p: p, p: p,
claims: claims, claims: claims,
token: tok, token: tok,
count: 10,
}
},
"ok/cnf": func(t *testing.T) test {
p, err := generateX5C(nil)
require.NoError(t, err)
id, err := randutil.ASCII(64)
require.NoError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{SSH: &SignSSHOptions{
CertType: SSHHostCert,
Principals: []string{"host.smallstep.com"},
}},
Confirmation: &cnfPayload{
Fingerprint: "fingerprint",
},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
require.NoError(t, err)
return test{
p: p,
claims: claims,
token: tok,
fingerprint: "fingerprint",
count: 10,
} }
}, },
} }
@ -769,9 +842,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equal(t, tc.code, sc.StatusCode())
} }
assert.HasPrefix(t, err.Error(), tc.err.Error()) assertHasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
@ -786,38 +859,34 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
tc.claims.Step.SSH.ValidAfter.t = time.Time{} tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{} tc.claims.Step.SSH.ValidBefore.t = time.Time{}
if firstValidator { if firstValidator {
assert.Equals(t, SignSSHOptions(v), *tc.claims.Step.SSH) assert.Equal(t, *tc.claims.Step.SSH, SignSSHOptions(v))
} else { } else {
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{KeyID: tc.claims.Subject}) assert.Equal(t, SignSSHOptions{KeyID: tc.claims.Subject}, SignSSHOptions(v))
} }
firstValidator = false firstValidator = false
case sshCertValidAfterModifier: case sshCertValidAfterModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix()) assert.Equal(t, tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix(), int64(v))
case sshCertValidBeforeModifier: case sshCertValidBeforeModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix()) assert.Equal(t, tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix(), int64(v))
case *sshLimitDuration: case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equal(t, tc.p.ctl.Claimer, v.Claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) assert.Equal(t, x5cCerts[0].NotAfter, v.NotAfter)
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equal(t, tc.p.ctl.Claimer, v.Claimer)
case *sshNamePolicyValidator: case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine) assert.Nil(t, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine) assert.Nil(t, v.hostPolicyEngine)
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
case *WebhookController: case *WebhookController:
assert.Len(t, 0, v.webhooks) assert.Len(t, v.webhooks, 0)
assert.Equals(t, linkedca.Webhook_SSH, v.certType) assert.Equal(t, linkedca.Webhook_SSH, v.certType)
assert.Len(t, 2, v.options) assert.Len(t, v.options, 2)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }
if tc.claims.Step.SSH.CertType != "" { assert.Equal(t, tc.count, tot)
assert.Equals(t, tot, 12)
} else {
assert.Equals(t, tot, 10)
}
} }
} }
} }

@ -154,12 +154,16 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) {
var ( var (
certOptions []sshutil.Option certOptions []sshutil.Option
mods []provisioner.SSHCertModifier mods []provisioner.SSHCertModifier
validators []provisioner.SSHCertValidator validators []provisioner.SSHCertValidator
keyValidators []provisioner.SSHPublicKeyValidator
) )
// Validate given options. // Validate given key and options
if key == nil {
return nil, nil, errs.BadRequest("ssh public key cannot be nil")
}
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -183,6 +187,10 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi
case provisioner.SSHCertModifier: case provisioner.SSHCertModifier:
mods = append(mods, o) mods = append(mods, o)
// validate the ssh public key
case provisioner.SSHPublicKeyValidator:
keyValidators = append(keyValidators, o)
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
@ -202,6 +210,16 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi
} }
} }
// Validate public key
for _, v := range keyValidators {
if err := v.Valid(key); err != nil {
return nil, nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
)
}
}
// Simulated certificate request with request options. // Simulated certificate request with request options.
cr := sshutil.CertificateRequest{ cr := sshutil.CertificateRequest{
Type: opts.CertType, Type: opts.CertType,

@ -19,6 +19,7 @@ import (
"testing" "testing"
"time" "time"
"go.step.sm/crypto/fingerprint"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/minica" "go.step.sm/crypto/minica"
@ -224,15 +225,6 @@ func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
return hash[:], nil return hash[:], nil
} }
func assertHasPrefix(t *testing.T, s, p string) bool {
if strings.HasPrefix(s, p) {
return true
}
t.Helper()
t.Errorf("%q is not a prefix of %q", p, s)
return false
}
type basicConstraints struct { type basicConstraints struct {
IsCA bool `asn1:"optional"` IsCA bool `asn1:"optional"`
MaxPathLen int `asn1:"optional,default:-1"` MaxPathLen int `asn1:"optional,default:-1"`
@ -249,6 +241,11 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
return nil return nil
} }
func assertHasPrefix(t *testing.T, s, p string) bool {
t.Helper()
return assert.True(t, strings.HasPrefix(s, p), "%q is not a prefix of %q", p, s)
}
func TestAuthority_SignWithContext(t *testing.T) { func TestAuthority_SignWithContext(t *testing.T) {
pub, priv, err := keyutil.GenerateDefaultKeyPair() pub, priv, err := keyutil.GenerateDefaultKeyPair()
require.NoError(t, err) require.NoError(t, err)
@ -605,6 +602,43 @@ ZYtQ9Ot36qc=
code: http.StatusForbidden, code: http.StatusForbidden,
} }
}, },
"fail with cnf": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
auth := testAuthority(t)
auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
auth.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equal(t, crt.Subject.CommonName, "smallstep test")
assert.Equal(t, crt.DNSNames, []string{"test.smallstep.com"})
return nil
},
}
// Create a token with cnf
tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{
"sans": []string{"test.smallstep.com"},
"cnf": map[string]any{"x5rt#S256": "bad-fingerprint"},
})
require.NoError(t, err)
opts, err := auth.Authorize(ctx, tok)
require.NoError(t, err)
return &signTest{
auth: auth,
csr: csr,
extraOpts: opts,
signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
err: errors.New(`certificate request fingerprint does not match "bad-fingerprint"`),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv) csr := getCSR(t, priv)
_a := testAuthority(t) _a := testAuthority(t)
@ -852,6 +886,44 @@ ZYtQ9Ot36qc=
extensionsCount: 6, extensionsCount: 6,
} }
}, },
"ok with cnf": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
fingerprint, err := fingerprint.New(csr.Raw, crypto.SHA256, fingerprint.Base64RawURLFingerprint)
require.NoError(t, err)
auth := testAuthority(t)
auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
auth.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equal(t, crt.Subject.CommonName, "smallstep test")
assert.Equal(t, crt.DNSNames, []string{"test.smallstep.com"})
return nil
},
}
// Create a token with cnf
tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{
"sans": []string{"test.smallstep.com"},
"cnf": map[string]any{"x5rt#S256": fingerprint},
})
require.NoError(t, err)
opts, err := auth.Authorize(ctx, tok)
require.NoError(t, err)
return &signTest{
auth: auth,
csr: csr,
extraOpts: opts,
signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
extensionsCount: 6,
}
},
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {

@ -625,7 +625,7 @@ func TestCARenew(t *testing.T) {
cert, err := x509util.NewCertificate(cr) cert, err := x509util.NewCertificate(cr)
assert.FatalError(t, err) assert.FatalError(t, err)
crt := cert.GetCertificate() crt := cert.GetCertificate()
crt.NotBefore = time.Now() crt.NotBefore = now
crt.NotAfter = leafExpiry crt.NotAfter = leafExpiry
crt, err = x509util.CreateCertificate(crt, intermediateCert, pub, intermediateKey.(crypto.Signer)) crt, err = x509util.CreateCertificate(crt, intermediateCert, pub, intermediateKey.(crypto.Signer))
assert.FatalError(t, err) assert.FatalError(t, err)

Loading…
Cancel
Save