package softcas

import (
	"bytes"
	"context"
	"crypto"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"fmt"
	"io"
	"math/big"
	"reflect"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/certificates/cas/apiv1"
	"github.com/smallstep/certificates/kms"
	kmsapi "github.com/smallstep/certificates/kms/apiv1"
	"go.step.sm/crypto/pemutil"
	"go.step.sm/crypto/x509util"
)

// PEM fixtures parsed by mustIssuer and mustSigner to build the issuer
// certificate and the signer shared by every test in this file.
var (
	testIntermediatePem = `-----BEGIN CERTIFICATE-----
MIIBPjCB8aADAgECAhAk4aPIlsVvQg3gveApc3mIMAUGAytlcDAeMRwwGgYDVQQD
ExNTbWFsbHN0ZXAgVW5pdCBUZXN0MB4XDTIwMDkxNjAyMDgwMloXDTMwMDkxNDAy
MDgwMlowHjEcMBoGA1UEAxMTU21hbGxzdGVwIFVuaXQgVGVzdDAqMAUGAytlcAMh
ANLs3JCzECR29biut0NDsaLnh0BGij5eJx6VkdJPfS/ko0UwQzAOBgNVHQ8BAf8E
BAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQUup5qpZFMAFdgK7RB
xNzmUaQM8YwwBQYDK2VwA0EAAwcW25E/6bchyKwp3RRK1GXiPMDCc+hsTJxuOLWy
YM7ga829dU8X4pRcEEAcBndqCED/502excjEK7U9vCkFCg==
-----END CERTIFICATE-----`
	testIntermediateKeyPem = `-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEII9ZckcrDKlbhZKR0jp820Uz6mOMLFsq2JhI+Tl7WJwH
-----END PRIVATE KEY-----`
)

// Shared fixtures: the templates below are signed once at package
// initialization with testSigner, using testNow as the start of a 24h
// validity window, and serve as the expected certificates in the tests.
var (
	errTest      = errors.New("test error")
	testIssuer   = mustIssuer()
	testSigner   = mustSigner()
	testTemplate = &x509.Certificate{
		Subject:      pkix.Name{CommonName: "test.smallstep.com"},
		DNSNames:     []string{"test.smallstep.com"},
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		PublicKey:    testSigner.Public(),
		SerialNumber: big.NewInt(1234),
	}
	testRootTemplate = &x509.Certificate{
		Subject:               pkix.Name{CommonName: "Test Root CA"},
		KeyUsage:              x509.KeyUsageCRLSign | x509.KeyUsageCertSign,
		PublicKey:             testSigner.Public(),
		BasicConstraintsValid: true,
		IsCA:                  true,
		MaxPathLen:            1,
		SerialNumber:          big.NewInt(1234),
	}
	testIntermediateTemplate = &x509.Certificate{
		Subject:               pkix.Name{CommonName: "Test Intermediate CA"},
		KeyUsage:              x509.KeyUsageCRLSign | x509.KeyUsageCertSign,
		PublicKey:             testSigner.Public(),
		BasicConstraintsValid: true,
		IsCA:                  true,
		MaxPathLen:            0,
		MaxPathLenZero:        true,
		SerialNumber:          big.NewInt(1234),
	}
	testNow                        = time.Now()
	testSignedTemplate             = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour))
	testSignedRootTemplate         = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour))
	testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour))
)

// mockKeyManager is a key manager stub for the tests. Every method answers
// with signer (or testSigner when signer is nil) together with the matching
// err* field, so a test can inject a failure into a single operation.
type mockKeyManager struct {
	signer          crypto.Signer
	errGetPublicKey error
	errCreateKey    error
	errCreatesigner error
	errClose        error
}

// GetPublicKey returns the public key of the configured signer and
// errGetPublicKey. The request is ignored.
func (m *mockKeyManager) GetPublicKey(req *kmsapi.GetPublicKeyRequest) (crypto.PublicKey, error) {
	signer := testSigner
	if m.signer != nil {
		signer = m.signer
	}
	return signer.Public(), m.errGetPublicKey
}

// CreateKey returns a response built from the configured signer and
// errCreateKey. The request is ignored.
func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) {
	signer := testSigner
	if m.signer != nil {
		signer = m.signer
	}
	return &kmsapi.CreateKeyResponse{
		PrivateKey: signer,
		PublicKey:  signer.Public(),
	}, m.errCreateKey
}

// CreateSigner returns the configured signer and errCreatesigner. The
// request is ignored.
func (m *mockKeyManager) CreateSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer, error) {
	signer := testSigner
	if m.signer != nil {
		signer = m.signer
	}
	return signer, m.errCreatesigner
}

// CreateDecrypter is a no-op that returns a nil decrypter and a nil error.
func (m *mockKeyManager) CreateDecrypter(req *kmsapi.CreateDecrypterRequest) (crypto.Decrypter, error) {
	return nil, nil
}

// Close returns errClose.
func (m *mockKeyManager) Close() error {
	return m.errClose
}

// badSigner is a crypto.Signer whose Sign method always fails. It is used to
// exercise the signing error paths.
type badSigner struct{}

// Public returns the public key of testSigner.
func (b *badSigner) Public() crypto.PublicKey {
	return testSigner.Public()
}

// Sign always returns an error. The parameters are unused and left unnamed
// so that they do not shadow the crypto/rand package.
func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) {
	return nil, fmt.Errorf("💥")
}

// mockNow replaces the package-level now function with one that returns
// testNow, and restores the previous function when the test finishes.
func mockNow(t *testing.T) {
	t.Helper()
	tmp := now
	now = func() time.Time {
		return testNow
	}
	t.Cleanup(func() {
		now = tmp
	})
}

// mustIssuer parses testIntermediatePem into a certificate and panics if the
// fixture cannot be parsed.
func mustIssuer() *x509.Certificate {
	v, err := pemutil.Parse([]byte(testIntermediatePem))
	if err != nil {
		panic(err)
	}
	return v.(*x509.Certificate)
}

// mustSigner parses testIntermediateKeyPem into a crypto.Signer and panics if
// the fixture cannot be parsed.
func mustSigner() crypto.Signer {
	v, err := pemutil.Parse([]byte(testIntermediateKeyPem))
	if err != nil {
		panic(err)
	}
	return v.(crypto.Signer)
}

// mustSign signs a copy of template with testSigner, using parent as the
// issuer and the given validity window. The template itself is not modified.
// It panics if the certificate cannot be created.
func mustSign(template, parent *x509.Certificate, notBefore, notAfter time.Time) *x509.Certificate {
	tmpl := *template
	tmpl.NotBefore = notBefore
	tmpl.NotAfter = notAfter
	tmpl.Issuer = parent.Subject
	cert, err := x509util.CreateCertificate(&tmpl, parent, tmpl.PublicKey, testSigner)
	if err != nil {
		panic(err)
	}
	return cert
}

// setTeeReader swaps rand.Reader for a reader that copies everything read
// from the original reader into w. The original reader is restored when the
// test finishes.
func setTeeReader(t *testing.T, w *bytes.Buffer) {
	t.Helper()
	reader := rand.Reader
	t.Cleanup(func() {
		rand.Reader = reader
	})
	rand.Reader = io.TeeReader(reader, w)
}

// TestNew checks that New requires both a certificate chain and a signer.
func TestNew(t *testing.T) {
	type args struct {
		ctx  context.Context
		opts apiv1.Options
	}
	tests := []struct {
		name    string
		args    args
		want    *SoftCAS
		wantErr bool
	}{
		{"ok", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false},
		{"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true},
		{"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := New(tt.args.ctx, tt.args.opts)
			if (err != nil) != tt.wantErr {
				t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("New() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestNew_register checks that a constructor is registered under
// apiv1.SoftCAS and that it builds the expected SoftCAS.
func TestNew_register(t *testing.T) {
	newFn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.SoftCAS)
	if !ok {
		t.Error("apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.SoftCAS) was not found")
		return
	}

	want := &SoftCAS{
		CertificateChain: []*x509.Certificate{testIssuer},
		Signer:           testSigner,
	}

	got, err := newFn(context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner})
	if err != nil {
		t.Errorf("New() error = %v", err)
		return
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("New() = %v, want %v", got, want)
	}
}

// TestSoftCAS_CreateCertificate covers the validity-window handling and the
// request validation of SoftCAS.CreateCertificate.
func TestSoftCAS_CreateCertificate(t *testing.T) {
	mockNow(t)
	// Set rand.Reader to EOF. setTeeReader is only called for its cleanup,
	// which restores the original reader; the reader is then replaced by an
	// empty buffer so that every read fails with io.EOF.
	// NOTE(review): presumably this is what makes the template without a
	// serial number fail in "fail CreateCertificate" - confirm in softcas.go.
	buf := new(bytes.Buffer)
	setTeeReader(t, buf)
	rand.Reader = buf

	tmplNotBefore := *testTemplate
	tmplNotBefore.NotBefore = testNow

	tmplNotAfter := *testTemplate
	tmplNotAfter.NotAfter = testNow.Add(24 * time.Hour)

	tmplWithLifetime := *testTemplate
	tmplWithLifetime.NotBefore = testNow
	tmplWithLifetime.NotAfter = testNow.Add(24 * time.Hour)

	tmplNoSerial := *testTemplate
	tmplNoSerial.SerialNumber = nil

	type fields struct {
		Issuer *x509.Certificate
		Signer crypto.Signer
	}
	type args struct {
		req *apiv1.CreateCertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *apiv1.CreateCertificateResponse
		wantErr bool
	}{
		{"ok", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
			Template: testTemplate, Lifetime: 24 * time.Hour,
		}}, &apiv1.CreateCertificateResponse{
			Certificate:      testSignedTemplate,
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
		{"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
			Template: &tmplNotBefore, Lifetime: 24 * time.Hour,
		}}, &apiv1.CreateCertificateResponse{
			Certificate:      testSignedTemplate,
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
		{"ok with notBefore+notAfter", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
			Template: &tmplWithLifetime, Lifetime: 24 * time.Hour,
		}}, &apiv1.CreateCertificateResponse{
			Certificate:      testSignedTemplate,
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
		{"fail template", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
		{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true},
		{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
			Template: &tmplNoSerial,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &SoftCAS{
				CertificateChain: []*x509.Certificate{tt.fields.Issuer},
				Signer:           tt.fields.Signer,
			}
			got, err := c.CreateCertificate(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("SoftCAS.CreateCertificate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SoftCAS.CreateCertificate() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestSoftCAS_RenewCertificate covers the success path and the request
// validation of SoftCAS.RenewCertificate.
func TestSoftCAS_RenewCertificate(t *testing.T) {
	mockNow(t)

	// Set rand.Reader to EOF. setTeeReader is only called for its cleanup,
	// which restores the original reader; the reader is then replaced by an
	// empty buffer so that every read fails with io.EOF.
	buf := new(bytes.Buffer)
	setTeeReader(t, buf)
	rand.Reader = buf

	tmplNoSerial := *testTemplate
	tmplNoSerial.SerialNumber = nil

	type fields struct {
		Issuer *x509.Certificate
		Signer crypto.Signer
	}
	type args struct {
		req *apiv1.RenewCertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *apiv1.RenewCertificateResponse
		wantErr bool
	}{
		{"ok", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{
			Template: testTemplate, Lifetime: 24 * time.Hour,
		}}, &apiv1.RenewCertificateResponse{
			Certificate:      testSignedTemplate,
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
		{"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
		{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true},
		{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{
			Template: &tmplNoSerial,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &SoftCAS{
				CertificateChain: []*x509.Certificate{tt.fields.Issuer},
				Signer:           tt.fields.Signer,
			}
			got, err := c.RenewCertificate(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("SoftCAS.RenewCertificate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SoftCAS.RenewCertificate() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestSoftCAS_RevokeCertificate checks that RevokeCertificate echoes the
// request certificate (if any) together with the configured chain.
func TestSoftCAS_RevokeCertificate(t *testing.T) {
	type fields struct {
		Issuer *x509.Certificate
		Signer crypto.Signer
	}
	type args struct {
		req *apiv1.RevokeCertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *apiv1.RevokeCertificateResponse
		wantErr bool
	}{
		{"ok", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{
			Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
			Reason:      "test reason",
			ReasonCode:  1,
		}}, &apiv1.RevokeCertificateResponse{
			Certificate:      &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
		{"ok no cert", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{
			Reason:     "test reason",
			ReasonCode: 1,
		}}, &apiv1.RevokeCertificateResponse{
			Certificate:      nil,
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
		{"ok empty", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{
			Certificate:      nil,
			CertificateChain: []*x509.Certificate{testIssuer},
		}, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &SoftCAS{
				CertificateChain: []*x509.Certificate{tt.fields.Issuer},
				Signer:           tt.fields.Signer,
			}
			got, err := c.RevokeCertificate(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("SoftCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SoftCAS.RevokeCertificate() = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_now checks that the default now function returns a time within one
// second of the current time, in either direction, so that a zero or stale
// value is rejected as well as one in the future.
func Test_now(t *testing.T) {
	t0 := time.Now()
	t1 := now()
	if d := t1.Sub(t0); d < -time.Second || d > time.Second {
		t.Errorf("now() = %s, want ~%s", t1, t0)
	}
}

// TestSoftCAS_CreateCertificateAuthority covers root and intermediate
// creation as well as the validation, key manager and signing failures of
// SoftCAS.CreateCertificateAuthority.
func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
	mockNow(t)

	type fields struct {
		Issuer     *x509.Certificate
		Signer     crypto.Signer
		KeyManager kms.KeyManager
	}
	type args struct {
		req *apiv1.CreateCertificateAuthorityRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *apiv1.CreateCertificateAuthorityResponse
		wantErr bool
	}{
		{"ok root", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.RootCA,
			Template: testRootTemplate,
			Lifetime: 24 * time.Hour,
		}}, &apiv1.CreateCertificateAuthorityResponse{
			Name:        "Test Root CA",
			Certificate: testSignedRootTemplate,
			PublicKey:   testSignedRootTemplate.PublicKey,
			PrivateKey:  testSigner,
			Signer:      testSigner,
		}, false},
		{"ok intermediate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.IntermediateCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
			Parent: &apiv1.CreateCertificateAuthorityResponse{
				Certificate: testSignedRootTemplate,
				Signer:      testSigner,
			},
		}}, &apiv1.CreateCertificateAuthorityResponse{
			Name:             "Test Intermediate CA",
			Certificate:      testSignedIntermediateTemplate,
			CertificateChain: []*x509.Certificate{testSignedRootTemplate},
			PublicKey:        testSignedIntermediateTemplate.PublicKey,
			PrivateKey:       testSigner,
			Signer:           testSigner,
		}, false},
		{"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.RootCA,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
		{"fail lifetime", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.RootCA,
			Template: testIntermediateTemplate,
		}}, nil, true},
		{"fail type", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
		{"fail parent", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.IntermediateCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
		{"fail parent.certificate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.IntermediateCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
			Parent: &apiv1.CreateCertificateAuthorityResponse{
				Signer: testSigner,
			},
		}}, nil, true},
		{"fail parent.signer", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.IntermediateCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
			Parent: &apiv1.CreateCertificateAuthorityResponse{
				Certificate: testSignedRootTemplate,
			},
		}}, nil, true},
		{"fail createKey", fields{nil, nil, &mockKeyManager{errCreateKey: errTest}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.RootCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
		{"fail createSigner", fields{nil, nil, &mockKeyManager{errCreatesigner: errTest}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.RootCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
		{"fail sign root", fields{nil, nil, &mockKeyManager{signer: &badSigner{}}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.RootCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
		}}, nil, true},
		{"fail sign intermediate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
			Type:     apiv1.IntermediateCA,
			Template: testIntermediateTemplate,
			Lifetime: 24 * time.Hour,
			Parent: &apiv1.CreateCertificateAuthorityResponse{
				Certificate: testSignedRootTemplate,
				Signer:      &badSigner{},
			},
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &SoftCAS{
				CertificateChain: []*x509.Certificate{tt.fields.Issuer},
				Signer:           tt.fields.Signer,
				KeyManager:       tt.fields.KeyManager,
			}
			got, err := c.CreateCertificateAuthority(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SoftCAS.CreateCertificateAuthority() = \n%#v, want \n%#v", got, tt.want)
			}
		})
	}
}