smallstep-certificates/acme/db/nosql/certificate_test.go
2021-11-28 21:20:57 +01:00

471 lines
13 KiB
Go

package nosql
import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/crypto/pemutil"
)
// TestDB_CreateCertificate drives DB.CreateCertificate through a mocked nosql
// backend. One case forces CmpAndSwap to return an error; the other lets the
// writes to certTable and certBySerialTable succeed and checks their payloads.
func TestDB_CreateCertificate(t *testing.T) {
	leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
	assert.FatalError(t, err)
	inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
	assert.FatalError(t, err)
	root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
	assert.FatalError(t, err)

	// newCert hands every case its own certificate value built from the
	// shared leaf/intermediate/root fixtures.
	newCert := func() *acme.Certificate {
		return &acme.Certificate{
			AccountID:     "accountID",
			OrderID:       "orderID",
			Leaf:          leaf,
			Intermediates: []*x509.Certificate{inter, root},
		}
	}

	// verifyCertRecord checks a CmpAndSwap call aimed at certTable: the key
	// must be the certificate ID, there must be no previous value, and the
	// JSON payload must decode into a dbCert that mirrors want and carries a
	// CreatedAt within a minute of clock.Now().
	verifyCertRecord := func(t *testing.T, want *acme.Certificate, bucket, key, old, nu []byte) {
		assert.Equals(t, bucket, certTable)
		assert.Equals(t, key, []byte(want.ID))
		assert.Equals(t, old, nil)

		stored := new(dbCert)
		assert.FatalError(t, json.Unmarshal(nu, stored))
		assert.Equals(t, stored.ID, string(key))
		assert.Equals(t, stored.ID, want.ID)
		assert.Equals(t, stored.AccountID, want.AccountID)
		assert.True(t, clock.Now().Add(-time.Minute).Before(stored.CreatedAt))
		assert.True(t, clock.Now().Add(time.Minute).After(stored.CreatedAt))
	}

	type fixture struct {
		db      nosql.DB
		cert    *acme.Certificate
		err     error
		savedID *string
	}

	cases := map[string]func(t *testing.T) fixture{
		"fail/cmpAndSwap-error": func(t *testing.T) fixture {
			cert := newCert()
			return fixture{
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						verifyCertRecord(t, cert, bucket, key, old, nu)
						return nil, false, errors.New("force")
					},
				},
				cert: cert,
				err:  errors.New("error saving acme certificate: force"),
			}
		},
		"ok": func(t *testing.T) fixture {
			cert := newCert()
			// savedID records the ID observed by the mock so the run loop can
			// compare it with the ID left on the certificate.
			savedID := new(string)
			return fixture{
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						toCerts := bytes.Equal(bucket, certTable)
						toSerials := bytes.Equal(bucket, certBySerialTable)

						// Any other bucket is an unexpected write.
						if !toCerts && !toSerials {
							t.Fail()
						}
						if toCerts {
							*savedID = string(key)
							verifyCertRecord(t, cert, bucket, key, old, nu)
						}
						if toSerials {
							assert.Equals(t, bucket, certBySerialTable)
							assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String()))
							assert.Equals(t, old, nil)

							index := new(dbSerial)
							assert.FatalError(t, json.Unmarshal(nu, index))
							assert.Equals(t, index.Serial, string(key))
							assert.Equals(t, index.CertificateID, cert.ID)
							*savedID = cert.ID
						}
						return nil, true, nil
					},
				},
				savedID: savedID,
				cert:    cert,
			}
		},
	}

	for name, setup := range cases {
		tc := setup(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			err := d.CreateCertificate(context.Background(), tc.cert)
			switch {
			case err != nil:
				// A failure is only acceptable when the case expects one.
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			case assert.Nil(t, tc.err):
				assert.Equals(t, tc.cert.ID, *tc.savedID)
			}
		})
	}
}
// TestDB_GetCertificate exercises DB.GetCertificate against a mocked nosql
// backend. The cases cover a not-found lookup (reported as an *acme.Error), a
// generic Get failure, a payload that is not valid JSON, a leaf stored under
// a non-certificate PEM type, and a well-formed record.
func TestDB_GetCertificate(t *testing.T) {
	leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
	assert.FatalError(t, err)
	inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
	assert.FatalError(t, err)
	root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
	assert.FatalError(t, err)

	certID := "certID"

	// pemBlock encodes der as a single PEM block of the given type.
	pemBlock := func(typ string, der []byte) []byte {
		return pem.EncodeToMemory(&pem.Block{Type: typ, Bytes: der})
	}

	type test struct {
		db      nosql.DB
		err     error       // expected prefix of a non-ACME error
		acmeErr *acme.Error // expected ACME error, compared field by field
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/not-found": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, certTable)
						assert.Equals(t, string(key), certID)
						return nil, nosqldb.ErrNotFound
					},
				},
				acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
			}
		},
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, certTable)
						assert.Equals(t, string(key), certID)
						return nil, errors.Errorf("force")
					},
				},
				err: errors.New("error loading certificate certID: force"),
			}
		},
		"fail/unmarshal-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, certTable)
						assert.Equals(t, string(key), certID)
						return []byte("foobar"), nil
					},
				},
				err: errors.New("error unmarshaling certificate certID"),
			}
		},
		"fail/parseBundle-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, certTable)
						assert.Equals(t, string(key), certID)
						cert := dbCert{
							ID:        certID,
							AccountID: "accountID",
							OrderID:   "orderID",
							// Deliberately the wrong PEM type for a certificate.
							Leaf:      pemBlock("Public Key", leaf.Raw),
							CreatedAt: clock.Now(),
						}
						b, err := json.Marshal(cert)
						assert.FatalError(t, err)
						return b, nil
					},
				},
				err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"),
			}
		},
		"ok": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, certTable)
						assert.Equals(t, string(key), certID)
						cert := dbCert{
							ID:        certID,
							AccountID: "accountID",
							OrderID:   "orderID",
							Leaf:      pemBlock("CERTIFICATE", leaf.Raw),
							Intermediates: append(
								pemBlock("CERTIFICATE", inter.Raw),
								pemBlock("CERTIFICATE", root.Raw)...,
							),
							CreatedAt: clock.Now(),
						}
						b, err := json.Marshal(cert)
						assert.FatalError(t, err)
						return b, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			cert, err := d.GetCertificate(context.Background(), certID)
			if err != nil {
				switch k := err.(type) {
				case *acme.Error:
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, k.Type, tc.acmeErr.Type)
						assert.Equals(t, k.Detail, tc.acmeErr.Detail)
						assert.Equals(t, k.Status, tc.acmeErr.Status)
						assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
					}
				default:
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
				return
			}

			// Success is only valid when the case expects neither kind of
			// error. Checking acmeErr here makes an unexpected success fail
			// with a message instead of reaching the cert dereferences below.
			if tc.acmeErr != nil {
				t.Fatalf("GetCertificate returned no error; want ACME error: %v", tc.acmeErr.Err)
			}
			if assert.Nil(t, tc.err) {
				assert.Equals(t, cert.ID, certID)
				assert.Equals(t, cert.AccountID, "accountID")
				assert.Equals(t, cert.OrderID, "orderID")
				assert.Equals(t, cert.Leaf, leaf)
				assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
			}
		})
	}
}
// Test_parseBundle feeds parseBundle PEM inputs that use a non-certificate
// block type, carry undecodable certificate bytes, or have trailing garbage,
// and finally a valid three-certificate bundle.
func Test_parseBundle(t *testing.T) {
	leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
	assert.FatalError(t, err)
	inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
	assert.FatalError(t, err)
	root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
	assert.FatalError(t, err)

	// chain is both the source of the valid bundle and the expected result.
	chain := []*x509.Certificate{leaf, inter, root}

	// encode wraps der in a single PEM block of the given type.
	encode := func(typ string, der []byte) []byte {
		return pem.EncodeToMemory(&pem.Block{Type: typ, Bytes: der})
	}

	var bundle []byte
	for _, c := range chain {
		bundle = append(bundle, encode("CERTIFICATE", c.Raw)...)
	}

	type scenario struct {
		input   []byte
		wantErr error
	}
	scenarios := map[string]scenario{
		"fail/bad-type-error": {
			input:   encode("Public Key", leaf.Raw),
			wantErr: errors.Errorf("error decoding PEM: data contains block that is not a certificate"),
		},
		"fail/bad-pem-error": {
			input:   encode("CERTIFICATE", []byte("foo")),
			wantErr: errors.Errorf("error parsing x509 certificate"),
		},
		"fail/unexpected-data": {
			input:   append(encode("CERTIFICATE", leaf.Raw), []byte("foo")...),
			wantErr: errors.Errorf("error decoding PEM: unexpected data"),
		},
		"ok": {
			input: bundle,
		},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			got, err := parseBundle(sc.input)
			switch {
			case err != nil:
				// Only the prefix is compared, so wrapped causes may follow.
				if assert.NotNil(t, sc.wantErr) {
					assert.HasPrefix(t, err.Error(), sc.wantErr.Error())
				}
			case assert.Nil(t, sc.wantErr):
				assert.Equals(t, got, chain)
			}
		})
	}
}
// TestDB_GetCertificateBySerial exercises DB.GetCertificateBySerial against a
// mocked nosql backend. The failure cases answer reads on certBySerialTable
// with not-found, a generic error, or malformed JSON; the ok case answers
// reads on both certBySerialTable and certTable with valid records. Reads on
// any other bucket return a "wrong table" error.
func TestDB_GetCertificateBySerial(t *testing.T) {
	leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
	assert.FatalError(t, err)
	inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
	assert.FatalError(t, err)
	root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
	assert.FatalError(t, err)

	certID := "certID"
	// NOTE(review): serial is the empty string and the mocks never inspect the
	// key, so the lookup key is not exercised here — consider a real serial.
	serial := ""

	// pemBlock encodes der as a single PEM block of the given type.
	pemBlock := func(typ string, der []byte) []byte {
		return pem.EncodeToMemory(&pem.Block{Type: typ, Bytes: der})
	}

	type test struct {
		db      nosql.DB
		err     error       // expected prefix of a non-ACME error
		acmeErr *acme.Error // expected ACME error, compared field by field
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/not-found": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						if bytes.Equal(bucket, certBySerialTable) {
							return nil, nosqldb.ErrNotFound
						}
						return nil, errors.New("wrong table")
					},
				},
				acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial),
			}
		},
		"fail/db-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						if bytes.Equal(bucket, certBySerialTable) {
							return nil, errors.New("force")
						}
						return nil, errors.New("wrong table")
					},
				},
				err: fmt.Errorf("error loading certificate ID for serial %s", serial),
			}
		},
		"fail/unmarshal-dbSerial": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						if bytes.Equal(bucket, certBySerialTable) {
							return []byte(`{"serial":malformed!}`), nil
						}
						return nil, errors.New("wrong table")
					},
				},
				err: fmt.Errorf("error unmarshaling certificate with serial %s", serial),
			}
		},
		"ok": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						if bytes.Equal(bucket, certBySerialTable) {
							certSerial := dbSerial{
								Serial:        serial,
								CertificateID: certID,
							}
							b, err := json.Marshal(certSerial)
							assert.FatalError(t, err)
							return b, nil
						}
						if bytes.Equal(bucket, certTable) {
							cert := dbCert{
								ID:        certID,
								AccountID: "accountID",
								OrderID:   "orderID",
								Leaf:      pemBlock("CERTIFICATE", leaf.Raw),
								Intermediates: append(
									pemBlock("CERTIFICATE", inter.Raw),
									pemBlock("CERTIFICATE", root.Raw)...,
								),
								CreatedAt: clock.Now(),
							}
							b, err := json.Marshal(cert)
							assert.FatalError(t, err)
							return b, nil
						}
						return nil, errors.New("wrong table")
					},
				},
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			cert, err := d.GetCertificateBySerial(context.Background(), serial)
			if err != nil {
				switch k := err.(type) {
				case *acme.Error:
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, k.Type, tc.acmeErr.Type)
						assert.Equals(t, k.Detail, tc.acmeErr.Detail)
						assert.Equals(t, k.Status, tc.acmeErr.Status)
						assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
					}
				default:
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
				return
			}

			// Success is only valid when the case expects neither kind of
			// error. Checking acmeErr here makes an unexpected success fail
			// with a message instead of reaching the cert dereferences below.
			if tc.acmeErr != nil {
				t.Fatalf("GetCertificateBySerial returned no error; want ACME error: %v", tc.acmeErr.Err)
			}
			if assert.Nil(t, tc.err) {
				assert.Equals(t, cert.ID, certID)
				assert.Equals(t, cert.AccountID, "accountID")
				assert.Equals(t, cert.OrderID, "orderID")
				assert.Equals(t, cert.Leaf, leaf)
				assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
			}
		})
	}
}