diff --git a/authority/authorize_test.go b/authority/authorize_test.go index bec34fd6..975ffc01 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -88,6 +88,39 @@ func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) { + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("kid", jwk.KeyID) + + for k, v := range extraHeaders { + so.WithHeader(jose.HeaderKey(k), v) + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + + iat := time.Now() + claims := jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + } + + return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize() +} + func TestAuthority_authorizeToken(t *testing.T) { a := testAuthority(t) @@ -491,7 +524,7 @@ func TestAuthority_authorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, 10, len(got)) // number of provisioner.SignOptions returned + assert.Equals(t, 11, len(got)) // number of provisioner.SignOptions returned } } }) diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 3a7512b8..2f73c8e5 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -19,8 +19,9 @@ import ( // jwtPayload extends jwt.Claims with step attributes. type jwtPayload struct { jose.Claims - SANs []string `json:"sans,omitempty"` - Step *stepPayload `json:"step,omitempty"` + SANs []string `json:"sans,omitempty"` + Step *stepPayload `json:"step,omitempty"` + Confirmation *cnfPayload `json:"cnf,omitempty"` } type stepPayload struct { @@ -28,6 +29,10 @@ type stepPayload struct { RA *RAInfo `json:"ra,omitempty"` } +type cnfPayload struct { + Kid string `json:"kid,omitempty"` +} + // JWK is the default provisioner, an entity that can sign tokens necessary for // signature requests. type JWK struct { @@ -183,6 +188,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } } + // Check the fingerprint of the certificate request if given. + var fingerprint string + if claims.Confirmation != nil { + fingerprint = claims.Confirmation.Kid + } + return []SignOption{ self, templateOptions, @@ -190,6 +201,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID).WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators + fingerprintValidator(fingerprint), commonNameSliceValidator(append([]string{claims.Subject}, claims.SANs...)), defaultPublicKeyValidator{}, newDefaultSANsValidator(ctx, claims.SANs), @@ -229,6 +241,11 @@ func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e sshCertOptionsValidator(SignSSHOptions{KeyID: claims.Subject}), } + // Check the fingerprint of the certificate request if given. + if claims.Confirmation != nil && claims.Confirmation.Kid != "" { + signOptions = append(signOptions, sshFingerprintValidator(claims.Confirmation.Kid)) + } + // Default template attributes. certType := sshutil.UserCert keyID := claims.Subject diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 794fe1ea..2471130a 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -13,7 +13,9 @@ import ( "testing" "time" + "go.step.sm/crypto/fingerprint" "go.step.sm/crypto/jose" + "golang.org/x/crypto/ssh" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" @@ -247,6 +249,9 @@ func TestJWK_AuthorizeSign(t *testing.T) { t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{}, time.Now(), key1) assert.FatalError(t, err) + t3, err := generateCustomToken("subject", p1.Name, testAudiences.Sign[0], key1, nil, map[string]any{"cnf": map[string]any{"kid": "fingerprint"}}) + assert.FatalError(t, err) + // invalid signature failSig := t1[0 : len(t1)-2] @@ -254,12 +259,13 @@ func TestJWK_AuthorizeSign(t *testing.T) { token string } tests := []struct { - name string - prov *JWK - args args - code int - err error - sans []string + name string + prov *JWK + args args + code int + err error + sans []string + fingerprint string }{ { name: "fail-signature", @@ -284,6 +290,15 @@ func TestJWK_AuthorizeSign(t *testing.T) { err: nil, sans: []string{"subject"}, }, + { + name: "ok-cnf", + prov: p1, + args: args{t3}, + code: http.StatusOK, + err: nil, + sans: []string{"subject"}, + fingerprint: "fingerprint", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -297,7 +312,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } } else { if assert.NotNil(t, got) { - assert.Equals(t, 10, len(got)) + assert.Equals(t, 11, len(got)) for _, o := range got { switch v := o.(type) { case *JWK: @@ -321,6 +336,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: + case fingerprintValidator: + assert.Equals(t, tt.fingerprint, string(v)) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } @@ -393,17 +410,6 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { jwk, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) - iss, aud := p1.Name, testAudiences.SSHSign[0] - - t1, err := generateSimpleSSHUserToken(iss, aud, jwk) - assert.FatalError(t, err) - - t2, err := generateSimpleSSHHostToken(iss, aud, jwk) - assert.FatalError(t, err) - - // invalid signature - failSig := t1[0 : len(t1)-2] - key, err := generateJSONWebKey() assert.FatalError(t, err) @@ -417,6 +423,39 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) + // Calculate fingerprint + sshPub, err := ssh.NewPublicKey(pub) + assert.FatalError(t, err) + fp, err := fingerprint.New(sshPub.Marshal(), crypto.SHA256, fingerprint.Base64RawURLFingerprint) + assert.FatalError(t, err) + + iss, aud := p1.Name, testAudiences.SSHSign[0] + + t1, err := generateSimpleSSHUserToken(iss, aud, jwk) + assert.FatalError(t, err) + + t2, err := generateSimpleSSHHostToken(iss, aud, jwk) + assert.FatalError(t, err) + + t3, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{ + "step": map[string]any{ + "ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}}, + }, + "cnf": map[string]any{"kid": fp}, + }) + assert.FatalError(t, err) + + t4, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{ + "step": map[string]any{ + "ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}}, + }, + "cnf": map[string]any{"kid": "bad-fingerprint"}, + }) + assert.FatalError(t, err) + + // invalid signature + failSig := t1[0 : len(t1)-2] + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ @@ -451,9 +490,11 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { {"host-type", p1, args{t2, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-principals", p1, args{t2, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-options", p1, args{t2, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"host-cnf", p1, args{t3, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"fail-sshCA-disabled", p2, args{"foo", SignSSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false}, {"fail-signature", p1, args{failSig, SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, - {"rail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, + {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, + {"fail-cnf", p1, args{t4, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusUnauthorized, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index fec9b9f6..62017fad 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -5,7 +5,10 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "crypto/sha256" + "crypto/subtle" "crypto/x509" + "encoding/base64" "encoding/json" "net" "net/http" @@ -492,3 +495,21 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption cert.ExtraExtensions = append(cert.ExtraExtensions, ext) return nil } + +// fingerprintValidator is a CertificateRequestValidator that checks the +// fingerprint of the certificate with the provided one. +type fingerprintValidator string + +func (s fingerprintValidator) Valid(cr *x509.CertificateRequest) error { + if s != "" { + expected, err := base64.RawURLEncoding.DecodeString(string(s)) + if err != nil { + return errs.ForbiddenErr(err, "error decoding fingerprint") + } + sum := sha256.Sum256(cr.Raw) + if subtle.ConstantTimeCompare(expected, sum[:]) != 1 { + return errs.Forbidden("certificate request fingerprint does not match %q", s) + } + } + return nil +} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index ee74ded3..648b4672 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -2,6 +2,9 @@ package provisioner import ( "crypto/rsa" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" "encoding/binary" "encoding/json" "fmt" @@ -44,6 +47,13 @@ type SSHCertOptionsValidator interface { Valid(got SignSSHOptions) error } +// SSHPublicKeyValidator is the interface used to validate the public key of an +// SSH certificate. +type SSHPublicKeyValidator interface { + SignOption + Valid(got ssh.PublicKey) error +} + // SignSSHOptions contains the options that can be passed to the SignSSH method. type SignSSHOptions struct { CertType string `json:"certType"` @@ -419,6 +429,24 @@ func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) } } +// sshFingerprintValidator is a SSHPublicKeyValidator that checks the +// fingerprint of the public key with the provided one. +type sshFingerprintValidator string + +func (s sshFingerprintValidator) Valid(key ssh.PublicKey) error { + if s != "" { + expected, err := base64.RawURLEncoding.DecodeString(string(s)) + if err != nil { + return errs.ForbiddenErr(err, "error decoding fingerprint") + } + sum := sha256.Sum256(key.Marshal()) + if subtle.ConstantTimeCompare(expected, sum[:]) != 1 { + return errs.Forbidden("ssh public key fingerprint does not match %q", s) + } + } + return nil +} + // sshCertTypeUInt32 func sshCertTypeUInt32(ct string) uint32 { switch ct { diff --git a/authority/provisioner/ssh_test.go b/authority/provisioner/ssh_test.go index 6ad71459..8d165e92 100644 --- a/authority/provisioner/ssh_test.go +++ b/authority/provisioner/ssh_test.go @@ -51,6 +51,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []Si var mods []SSHCertModifier var certOptions []sshutil.Option var validators []SSHCertValidator + var keyValidators []SSHPublicKeyValidator for _, op := range signOpts { switch o := op.(type) { @@ -71,11 +72,19 @@ func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []Si } // call webhooks case *WebhookController: + case sshFingerprintValidator: + keyValidators = append(keyValidators, o) default: return nil, fmt.Errorf("signSSH: invalid extra option type %T", o) } } + for _, v := range keyValidators { + if err := v.Valid(pub); err != nil { + return nil, err + } + } + // Simulated certificate request with request options. cr := sshutil.CertificateRequest{ Type: opts.CertType, diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index a599a835..18538d0c 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -765,6 +765,37 @@ func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jw return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) { + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("kid", jwk.KeyID) + + for k, v := range extraHeaders { + so.WithHeader(jose.HeaderKey(k), v) + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + iat := time.Now() + claims := jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + } + return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize() +} + func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 9b1f2b08..a1f6a497 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -21,9 +21,10 @@ import ( // x5cPayload extends jwt.Claims with step attributes. type x5cPayload struct { jose.Claims - SANs []string `json:"sans,omitempty"` - Step *stepPayload `json:"step,omitempty"` - chains [][]*x509.Certificate + SANs []string `json:"sans,omitempty"` + Step *stepPayload `json:"step,omitempty"` + Confirmation *cnfPayload `json:"cnf,omitempty"` + chains [][]*x509.Certificate } // X5C is the default provisioner, an entity that can sign tokens necessary for @@ -233,6 +234,12 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } } + // Check the fingerprint of the certificate request if given. + var fingerprint string + if claims.Confirmation != nil { + fingerprint = claims.Confirmation.Kid + } + return []SignOption{ self, templateOptions, @@ -243,6 +250,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er x5cLeaf.NotBefore, x5cLeaf.NotAfter, }, // validators + fingerprintValidator(fingerprint), commonNameValidator(claims.Subject), newDefaultSANsValidator(ctx, claims.SANs), defaultPublicKeyValidator{}, @@ -285,6 +293,11 @@ func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e sshCertOptionsValidator(SignSSHOptions{KeyID: claims.Subject}), } + // Check the fingerprint of the certificate request if given. + if claims.Confirmation != nil && claims.Confirmation.Kid != "" { + signOptions = append(signOptions, sshFingerprintValidator(claims.Confirmation.Kid)) + } + // Default template attributes. certType := sshutil.UserCert keyID := claims.Subject diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 22545446..eb3946dd 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -3,6 +3,7 @@ package provisioner import ( "context" "crypto/x509" + "encoding/base64" "errors" "fmt" "net/http" @@ -413,11 +414,12 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) type test struct { - p *X5C - token string - code int - err error - sans []string + p *X5C + token string + code int + err error + sans []string + fingerprint string } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { @@ -456,13 +458,36 @@ func TestX5C_AuthorizeSign(t *testing.T) { sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, } }, + "ok/cnf": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + + x5c := make([]string, len(certs)) + for i, cert := range certs { + x5c[i] = base64.StdEncoding.EncodeToString(cert.Raw) + } + extraHeaders := map[string]any{"x5c": x5c} + extraClaims := map[string]any{ + "sans": []string{"127.0.0.1", "foo", "max@smallstep.com"}, + "cnf": map[string]any{"kid": "fingerprint"}, + } + + tok, err := generateCustomToken("foo", p.GetName(), testAudiences.Sign[0], jwk, extraHeaders, extraClaims) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, + fingerprint: "fingerprint", + } + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) ctx := NewContextWithMethod(context.Background(), SignIdentityMethod) if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil { - if assert.NotNil(t, tc.err) { + if assert.NotNil(t, tc.err, err.Error()) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equals(t, sc.StatusCode(), tc.code) @@ -472,7 +497,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - assert.Equals(t, 10, len(opts)) + assert.Equals(t, 11, len(opts)) for _, o := range opts { switch v := o.(type) { case *X5C: @@ -502,6 +527,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Len(t, 0, v.webhooks) assert.Equals(t, linkedca.Webhook_X509, v.certType) assert.Len(t, 2, v.options) + case fingerprintValidator: + assert.Equals(t, tc.fingerprint, string(v)) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } @@ -627,11 +654,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { _, fn := mockNow() defer fn() type test struct { - p *X5C - token string - claims *x5cPayload - code int - err error + p *X5C + token string + claims *x5cPayload + fingerprint string + count int + code int + err error } tests := map[string]func(*testing.T) test{ "fail/sshCA-disabled": func(t *testing.T) test { @@ -719,7 +748,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { Audience: []string{testAudiences.SSHSign[0]}, }, Step: &stepPayload{SSH: &SignSSHOptions{ - CertType: SSHHostCert, + CertType: SSHUserCert, KeyID: "foo", Principals: []string{"max", "mariano", "alan"}, ValidAfter: TimeDuration{d: 5 * time.Minute}, @@ -732,6 +761,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { p: p, claims: claims, token: tok, + count: 12, } }, "ok/without-claims": func(t *testing.T) test { @@ -759,6 +789,42 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { p: p, claims: claims, token: tok, + count: 10, + } + }, + "ok/cnf": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + + id, err := randutil.ASCII(64) + assert.FatalError(t, err) + now := time.Now() + claims := &x5cPayload{ + Claims: jose.Claims{ + ID: id, + Subject: "foo", + Issuer: p.GetName(), + IssuedAt: jose.NewNumericDate(now), + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{testAudiences.SSHSign[0]}, + }, + Step: &stepPayload{SSH: &SignSSHOptions{ + CertType: SSHHostCert, + Principals: []string{"host.smallstep.com"}, + }}, + Confirmation: &cnfPayload{ + Kid: "fingerprint", + }, + } + tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) + assert.FatalError(t, err) + return test{ + p: p, + claims: claims, + token: tok, + fingerprint: "fingerprint", + count: 11, } }, } @@ -808,16 +874,14 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Len(t, 0, v.webhooks) assert.Equals(t, linkedca.Webhook_SSH, v.certType) assert.Len(t, 2, v.options) + case sshFingerprintValidator: + assert.Equals(t, tc.fingerprint, string(v)) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } - if len(tc.claims.Step.SSH.CertType) > 0 { - assert.Equals(t, tot, 12) - } else { - assert.Equals(t, tot, 10) - } + assert.Equals(t, tc.count, tot) } } } diff --git a/authority/ssh.go b/authority/ssh.go index f9371d60..868dd013 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -148,12 +148,16 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (* // SignSSH creates a signed SSH certificate with the given public key and options. func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var ( - certOptions []sshutil.Option - mods []provisioner.SSHCertModifier - validators []provisioner.SSHCertValidator + certOptions []sshutil.Option + mods []provisioner.SSHCertModifier + validators []provisioner.SSHCertValidator + keyValidators []provisioner.SSHPublicKeyValidator ) - // Validate given options. + // Validate given key and options + if key == nil { + return nil, errs.BadRequest("ssh public key cannot be nil") + } if err := opts.Validate(); err != nil { return nil, err } @@ -177,6 +181,10 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision case provisioner.SSHCertModifier: mods = append(mods, o) + // validate the ssh public key + case provisioner.SSHPublicKeyValidator: + keyValidators = append(keyValidators, o) + // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) @@ -196,6 +204,16 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } } + // Validate public key + for _, v := range keyValidators { + if err := v.Valid(key); err != nil { + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, err.Error()), + errs.WithKeyVal("signOptions", signOpts), + ) + } + } + // Simulated certificate request with request options. cr := sshutil.CertificateRequest{ Type: opts.CertType, diff --git a/authority/tls_test.go b/authority/tls_test.go index efcb78f8..f0192fea 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "go.step.sm/crypto/fingerprint" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" @@ -593,6 +594,43 @@ ZYtQ9Ot36qc= code: http.StatusForbidden, } }, + "fail with cnf": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + + auth := testAuthority(t) + auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + auth.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + assert.Equals(t, crt.DNSNames, []string{"test.smallstep.com"}) + return nil + }, + } + + // Create a token with cnf + tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{ + "sans": []string{"test.smallstep.com"}, + "cnf": map[string]any{"kid": "bad-fingerprint"}, + }) + assert.FatalError(t, err) + + opts, err := auth.Authorize(ctx, tok) + assert.FatalError(t, err) + + return &signTest{ + auth: auth, + csr: csr, + extraOpts: opts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + err: errors.New(`certificate request fingerprint does not match "bad-fingerprint"`), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *signTest { csr := getCSR(t, priv) _a := testAuthority(t) @@ -840,6 +878,44 @@ ZYtQ9Ot36qc= extensionsCount: 6, } }, + "ok with cnf": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + fingerprint, err := fingerprint.New(csr.Raw, crypto.SHA256, fingerprint.Base64RawURLFingerprint) + assert.FatalError(t, err) + + auth := testAuthority(t) + auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + auth.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + assert.Equals(t, crt.DNSNames, []string{"test.smallstep.com"}) + return nil + }, + } + + // Create a token with cnf + tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{ + "sans": []string{"test.smallstep.com"}, + "cnf": map[string]any{"kid": fingerprint}, + }) + assert.FatalError(t, err) + + opts, err := auth.Authorize(ctx, tok) + assert.FatalError(t, err) + + return &signTest{ + auth: auth, + csr: csr, + extraOpts: opts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, + } + }, } for name, genTestCase := range tests {