Fix existing unit tests.

pull/312/head
Mariano Cano 4 years ago
parent 497158d0f6
commit 0c8376a7f6

@ -28,6 +28,7 @@ type MockProvisioner struct {
MgetName func() string
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.ProvisionerOptions
}
// GetName mock
@ -54,6 +55,13 @@ func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
return m.Mret1.(time.Duration)
}
func (m *MockProvisioner) GetOptions() *provisioner.ProvisionerOptions {
if m.MgetOptions != nil {
return m.MgetOptions()
}
return m.Mret1.(*provisioner.ProvisionerOptions)
}
// ContextKey is the key type for storing and searching for ACME request
// essentials in the context of a request.
type ContextKey string

@ -1066,13 +1066,13 @@ func TestOrderFinalize(t *testing.T) {
Subject: pkix.Name{
CommonName: "",
},
DNSNames: []string{"acme.example.com", "step.example.com"},
// DNSNames: []string{"acme.example.com", "step.example.com"},
IPAddresses: []net.IP{net.ParseIP("1.1.1.1")},
}
return test{
o: o,
csr: csr,
err: BadCSRErr(errors.Errorf("CSR contains IP Address SANs, but should only contain DNS Names")),
err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
}
},
"fail/ready/no-emailAddresses": func(t *testing.T) test {
@ -1084,13 +1084,13 @@ func TestOrderFinalize(t *testing.T) {
Subject: pkix.Name{
CommonName: "",
},
DNSNames: []string{"acme.example.com", "step.example.com"},
// DNSNames: []string{"acme.example.com", "step.example.com"},
EmailAddresses: []string{"max@smallstep.com", "mariano@smallstep.com"},
}
return test{
o: o,
csr: csr,
err: BadCSRErr(errors.Errorf("CSR contains Email Address SANs, but should only contain DNS Names")),
err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
}
},
"fail/ready/no-URIs": func(t *testing.T) test {
@ -1104,13 +1104,13 @@ func TestOrderFinalize(t *testing.T) {
Subject: pkix.Name{
CommonName: "",
},
DNSNames: []string{"acme.example.com", "step.example.com"},
URIs: []*url.URL{u},
// DNSNames: []string{"acme.example.com", "step.example.com"},
URIs: []*url.URL{u},
}
return test{
o: o,
csr: csr,
err: BadCSRErr(errors.Errorf("CSR contains URI SANs, but should only contain DNS Names")),
err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
}
},
"fail/ready/provisioner-auth-sign-error": func(t *testing.T) test {
@ -1263,7 +1263,7 @@ func TestOrderFinalize(t *testing.T) {
csr: csr,
sa: &mockSignAuth{
sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, len(signOps), 5)
assert.Equals(t, len(signOps), 6)
return []*x509.Certificate{crt, inter}, nil
},
},
@ -1312,7 +1312,7 @@ func TestOrderFinalize(t *testing.T) {
csr: csr,
sa: &mockSignAuth{
sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, len(signOps), 5)
assert.Equals(t, len(signOps), 6)
return []*x509.Certificate{crt, inter}, nil
},
},
@ -1359,7 +1359,7 @@ func TestOrderFinalize(t *testing.T) {
csr: csr,
sa: &mockSignAuth{
sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, len(signOps), 5)
assert.Equals(t, len(signOps), 6)
return []*x509.Certificate{crt, inter}, nil
},
},

@ -540,11 +540,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1, "foo.local"}, 5, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 9, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 9, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 9, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 5, http.StatusOK, false},
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
@ -574,6 +574,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAWS))
assert.Equals(t, v.Name, tt.aws.GetName())

@ -431,9 +431,9 @@ func TestAzure_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 9, http.StatusOK, false},
{"ok", p1, args{t11}, 4, http.StatusOK, false},
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false},
{"ok", p1, args{t11}, 5, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
@ -458,6 +458,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAzure))
assert.Equals(t, v.Name, tt.azure.GetName())

@ -515,9 +515,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 9, http.StatusOK, false},
{"ok", p3, args{t3}, 4, http.StatusOK, false},
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false},
{"ok", p3, args{t3}, 5, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
@ -547,6 +547,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeGCP))
assert.Equals(t, v.Name, tt.gcp.GetName())

@ -295,9 +295,10 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 6, got)
assert.Len(t, 7, got)
for _, o := range got {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeJWK))
assert.Equals(t, v.Name, tt.prov.GetName())

@ -274,6 +274,7 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
tot := 0
for _, o := range opts {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeK8sSA))
assert.Equals(t, v.Name, tc.p.GetName())
@ -290,7 +291,7 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
}
tot++
}
assert.Equals(t, tot, 4)
assert.Equals(t, tot, 5)
}
}
}

@ -363,6 +363,10 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
}
// Enforce an email claim
if claims.Email == "" {
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign: failed to validate oidc token payload: email not found")
}
signOptions := []SignOption{
// set the key id to the token email
sshCertKeyIDModifier(claims.Email),

@ -179,12 +179,12 @@ func TestOIDC_authorizeToken(t *testing.T) {
assert.FatalError(t, err)
t4, err := generateToken("subject", issuer, p3.ClientID, "foo@smallstep.com", []string{}, time.Now(), &keys.Keys[2])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", issuer, p3.ClientID, "", []string{}, time.Now(), &keys.Keys[2])
t5, err := generateToken("subject", issuer, p3.ClientID, "", []string{}, time.Now(), &keys.Keys[2])
assert.FatalError(t, err)
// Invalid email
failDomain, err := generateToken("subject", issuer, p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[2])
assert.FatalError(t, err)
// Invalid tokens
parts := strings.Split(t1, ".")
key, err := generateJSONWebKey()
@ -226,7 +226,7 @@ func TestOIDC_authorizeToken(t *testing.T) {
{"ok tenantid", p2, args{t2}, http.StatusOK, tenantIssuer, false},
{"ok admin", p3, args{t3}, http.StatusOK, issuer, false},
{"ok domain", p3, args{t4}, http.StatusOK, issuer, false},
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, "", true},
{"ok no email", p3, args{t5}, http.StatusOK, issuer, false},
{"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, "", true},
{"fail-key", p1, args{failKey}, http.StatusUnauthorized, "", true},
{"fail-token", p1, args{failTok}, http.StatusUnauthorized, "", true},
@ -290,8 +290,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
// No email
noEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
@ -306,7 +306,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
}{
{"ok1", p1, args{t1}, http.StatusOK, false},
{"admin", p3, args{okAdmin}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
{"no-email", p3, args{noEmail}, http.StatusOK, false},
{"bad-token", p3, args{"foobar"}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -323,12 +324,13 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
} else {
if assert.NotNil(t, got) {
if tt.name == "admin" {
assert.Len(t, 4, got)
assert.Len(t, 5, got)
} else {
assert.Len(t, 5, got)
}
for _, o := range got {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeOIDC))
assert.Equals(t, v.Name, tt.prov.GetName())
@ -514,7 +516,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
// Empty email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)

@ -22,6 +22,7 @@ func (fn certificateOptionsFunc) Options(so Options) []x509util.Option {
// ProvisionerOptions are a collection of custom options that can be added to
// each provisioner.
// nolint:golint
type ProvisionerOptions struct {
// Template contains a X.509 certificate template. It can be a JSON template
// escaped in a string or it can be also encoded in base64.

@ -17,6 +17,9 @@ import (
"golang.org/x/crypto/ed25519"
)
// DefaultCertValidity is the default validity for a certificate if none is specified.
const DefaultCertValidity = 24 * time.Hour
// Options contains the options that can be passed to the Sign method. Backdate
// is automatically filled and can only be configured in the CA.
type Options struct {
@ -277,7 +280,11 @@ func (v profileDefaultDuration) Modify(cert *x509.Certificate, so Options) error
}
notAfter := so.NotAfter.RelativeTime(notBefore)
if notAfter.IsZero() {
notAfter = notBefore.Add(time.Duration(v))
if v != 0 {
notAfter = notBefore.Add(time.Duration(v))
} else {
notAfter = notBefore.Add(DefaultCertValidity)
}
}
cert.NotBefore = notBefore.Add(backdate)

@ -13,7 +13,6 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util"
)
func Test_emailOnlyIdentity_Valid(t *testing.T) {
@ -562,15 +561,13 @@ func Test_forceCN_Option(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
if err := tt.fcn.Option(tt.so)(prof); err != nil {
if err := tt.fcn.Modify(tt.cert, tt.so); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
if assert.Nil(t, tt.err) {
tt.valid(prof.Subject())
tt.valid(tt.cert)
}
}
})
@ -661,10 +658,9 @@ func Test_profileDefaultDuration_Option(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
assert.FatalError(t, tt.pdd.Option(tt.so)(prof), "unexpected error")
tt.valid(prof.Subject())
assert.FatalError(t, tt.pdd.Modify(tt.cert, tt.so), "unexpected error")
time.Sleep(1 * time.Nanosecond)
tt.valid(tt.cert)
})
}
}
@ -702,10 +698,8 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
assert.FatalError(t, newProvisionerExtensionOption(TypeJWK, "foo", "bar", "baz", "zap").Option(Options{})(prof))
tt.valid(prof.Subject())
assert.FatalError(t, newProvisionerExtensionOption(TypeJWK, "foo", "bar", "baz", "zap").Modify(tt.cert, Options{}))
tt.valid(tt.cert)
})
}
}
@ -803,15 +797,13 @@ func Test_profileLimitDuration_Option(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
if err := tt.pld.Option(tt.so)(prof); err != nil {
if err := tt.pld.Modify(tt.cert, tt.so); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
if assert.Nil(t, tt.err) {
tt.valid(prof.Subject())
tt.valid(tt.cert)
}
}
})

@ -463,9 +463,10 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
assert.Equals(t, len(opts), 6)
assert.Equals(t, len(opts), 7)
for _, o := range opts {
switch v := o.(type) {
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeX5C))
assert.Equals(t, v.Name, tc.p.GetName())

@ -2,11 +2,13 @@ package ca
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/json"
"encoding/pem"
"fmt"
@ -52,6 +54,22 @@ func getCSR(priv interface{}) (*x509.CertificateRequest, error) {
return x509.ParseCertificateRequest(csrBytes)
}
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
b, err := x509.MarshalPKIXPublicKey(pub)
if err != nil {
return nil, errors.Wrap(err, "error marshaling public key")
}
info := struct {
Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString
}{}
if _, err = asn1.Unmarshal(b, &info); err != nil {
return nil, errors.Wrap(err, "error unmarshaling public key")
}
hash := sha1.Sum(info.SubjectPublicKey.Bytes)
return hash[:], nil
}
func TestMain(m *testing.M) {
DisableIdentity = true
os.Exit(m.Run())
@ -299,10 +317,9 @@ ZEp7knvU2psWRw==
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
pubBytes, err := x509.MarshalPKIXPublicKey(pub)
subjectKeyID, err := generateSubjectKeyID(pub)
assert.FatalError(t, err)
hash := sha1.Sum(pubBytes)
assert.Equals(t, leaf.SubjectKeyId, hash[:])
assert.Equals(t, leaf.SubjectKeyId, subjectKeyID)
assert.Equals(t, leaf.AuthorityKeyId, intermediateIdentity.Crt.SubjectKeyId)
@ -641,11 +658,9 @@ func TestCARenew(t *testing.T) {
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"funk"})
pubBytes, err := x509.MarshalPKIXPublicKey(pub)
subjectKeyID, err := generateSubjectKeyID(pub)
assert.FatalError(t, err)
hash := sha1.Sum(pubBytes)
assert.Equals(t, leaf.SubjectKeyId, hash[:])
assert.Equals(t, leaf.SubjectKeyId, subjectKeyID)
assert.Equals(t, leaf.AuthorityKeyId, intermediateIdentity.Crt.SubjectKeyId)
realIntermediate, err := x509.ParseCertificate(intermediateIdentity.Crt.Raw)

@ -9,6 +9,7 @@ import (
// List of signature algorithms, all of them have values in upper case to match
// them with the string representation.
// nolint:golint
const (
MD2_RSA = "MD2-RSA"
MD5_RSA = "MD5-RSA"

Loading…
Cancel
Save