From 97165f184478134e44b2e6c8ed206650364bd7e6 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 9 Jul 2021 22:48:03 +0200 Subject: [PATCH] Fix test mocking for CreateCertificate --- acme/db.go | 1 + acme/db/nosql/certificate_test.go | 40 ++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/acme/db.go b/acme/db.go index c4b79a66..67053269 100644 --- a/acme/db.go +++ b/acme/db.go @@ -182,6 +182,7 @@ func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, e return m.MockRet1.(*Certificate), m.MockError } +// GetCertificateBySerial mock func (m *MockDB) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) { if m.MockGetCertificateBySerial != nil { return m.MockGetCertificateBySerial(ctx, serial) diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go index 4ec4589e..f5a8b67f 100644 --- a/acme/db/nosql/certificate_test.go +++ b/acme/db/nosql/certificate_test.go @@ -31,6 +31,7 @@ func TestDB_CreateCertificate(t *testing.T) { err error _id *string } + countOfCmpAndSwapCalls := 0 var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { cert := &acme.Certificate{ @@ -75,18 +76,35 @@ func TestDB_CreateCertificate(t *testing.T) { return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { - *idPtr = string(key) - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - assert.Equals(t, old, nil) + if countOfCmpAndSwapCalls == 0 { + *idPtr = string(key) + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + assert.Equals(t, old, nil) + + dbc := new(dbCert) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.ID, cert.ID) + assert.Equals(t, dbc.AccountID, cert.AccountID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + } + if countOfCmpAndSwapCalls == 1 { + assert.Equals(t, bucket, certBySerialTable) + assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String())) + assert.Equals(t, old, nil) + + dbs := new(dbSerial) + assert.FatalError(t, json.Unmarshal(nu, dbs)) + assert.Equals(t, dbs.Serial, string(key)) + assert.Equals(t, dbs.CertificateID, cert.ID) + + *idPtr = cert.ID + } + + countOfCmpAndSwapCalls += 1 - dbc := new(dbCert) - assert.FatalError(t, json.Unmarshal(nu, dbc)) - assert.Equals(t, dbc.ID, string(key)) - assert.Equals(t, dbc.ID, cert.ID) - assert.Equals(t, dbc.AccountID, cert.AccountID) - assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) - assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, true, nil }, },