[acme db interface] unit tests for challenge nosql db

pull/496/head
max furman 3 years ago
parent 4b1dda5bb6
commit 206909b12e

@ -10,7 +10,7 @@ import (
type Authorization struct {
Identifier Identifier `json:"identifier"`
Status Status `json:"status"`
Expires time.Time `json:"expires"`
ExpiresAt time.Time `json:"expires"`
Challenges []*Challenge `json:"challenges"`
Wildcard bool `json:"wildcard"`
ID string `json:"-"`
@ -39,7 +39,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
return nil
case StatusPending:
// check expiry
if now.After(az.Expires) {
if now.After(az.ExpiresAt) {
az.Status = StatusInvalid
break
}

@ -22,16 +22,16 @@ import (
// Challenge represents an ACME response Challenge type.
type Challenge struct {
Type string `json:"type"`
Status Status `json:"status"`
Token string `json:"token"`
Validated string `json:"validated,omitempty"`
URL string `json:"url"`
Error *Error `json:"error,omitempty"`
ID string `json:"-"`
AuthzID string `json:"-"`
AccountID string `json:"-"`
Value string `json:"-"`
Type string `json:"type"`
Status Status `json:"status"`
Token string `json:"token"`
ValidatedAt string `json:"validated,omitempty"`
URL string `json:"url"`
Error *Error `json:"error,omitempty"`
ID string `json:"-"`
AuthzID string `json:"-"`
AccountID string `json:"-"`
Value string `json:"-"`
}
// ToLog enables response logging.
@ -97,7 +97,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
// Update and store the challenge.
ch.Status = StatusValid
ch.Error = nil
ch.Validated = clock.Now().Format(time.RFC3339)
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "error updating challenge")
@ -175,7 +175,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
ch.Status = StatusValid
ch.Error = nil
ch.Validated = clock.Now().Format(time.RFC3339)
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge")
@ -231,7 +231,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
// Update and store the challenge.
ch.Status = StatusValid
ch.Error = nil
ch.Validated = clock.Now().UTC().Format(time.RFC3339)
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "error updating challenge")

@ -18,10 +18,10 @@ type dbAuthz struct {
AccountID string `json:"accountID"`
Identifier acme.Identifier `json:"identifier"`
Status acme.Status `json:"status"`
Expires time.Time `json:"expires"`
ExpiresAt time.Time `json:"expiresAt"`
Challenges []string `json:"challenges"`
Wildcard bool `json:"wildcard"`
Created time.Time `json:"created"`
CreatedAt time.Time `json:"createdAt"`
Error *acme.Error `json:"error"`
}
@ -66,7 +66,7 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat
Status: dbaz.Status,
Challenges: chs,
Wildcard: dbaz.Wildcard,
Expires: dbaz.Expires,
ExpiresAt: dbaz.ExpiresAt,
ID: dbaz.ID,
}, nil
}
@ -90,8 +90,8 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e
ID: az.ID,
AccountID: az.AccountID,
Status: acme.StatusPending,
Created: now,
Expires: now.Add(defaultExpiryDuration),
CreatedAt: now,
ExpiresAt: now.Add(defaultExpiryDuration),
Identifier: az.Identifier,
Challenges: chIDs,
Wildcard: az.Wildcard,

@ -14,7 +14,7 @@ import (
type dbCert struct {
ID string `json:"id"`
Created time.Time `json:"created"`
CreatedAt time.Time `json:"createdAt"`
AccountID string `json:"accountID"`
OrderID string `json:"orderID"`
Leaf []byte `json:"leaf"`
@ -47,7 +47,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
OrderID: cert.OrderID,
Leaf: leaf,
Intermediates: intermediates,
Created: time.Now().UTC(),
CreatedAt: time.Now().UTC(),
}
return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
}
@ -57,7 +57,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
b, err := db.db.Get(certTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, errors.Wrapf(err, "certificate %s not found", id)
return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading certificate %s", id)
}

@ -34,7 +34,7 @@ func TestDB_CreateCertificate(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
cert := &acme.Certificate{
AccountID: "accounttID",
AccountID: "accountID",
OrderID: "orderID",
Leaf: leaf,
Intermediates: []*x509.Certificate{inter, root},
@ -48,12 +48,11 @@ func TestDB_CreateCertificate(t *testing.T) {
dbc := new(dbCert)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.FatalError(t, err)
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.ID, cert.ID)
assert.Equals(t, dbc.AccountID, cert.AccountID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, false, errors.New("force")
},
},
@ -63,7 +62,7 @@ func TestDB_CreateCertificate(t *testing.T) {
},
"ok": func(t *testing.T) test {
cert := &acme.Certificate{
AccountID: "accounttID",
AccountID: "accountID",
OrderID: "orderID",
Leaf: leaf,
Intermediates: []*x509.Certificate{inter, root},
@ -83,12 +82,11 @@ func TestDB_CreateCertificate(t *testing.T) {
dbc := new(dbCert)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.FatalError(t, err)
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.ID, cert.ID)
assert.Equals(t, dbc.AccountID, cert.AccountID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, true, nil
},
},
@ -124,8 +122,9 @@ func TestDB_GetCertificate(t *testing.T) {
certID := "certID"
type test struct {
db nosql.DB
err error
db nosql.DB
err error
acmeErr *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
@ -138,7 +137,7 @@ func TestDB_GetCertificate(t *testing.T) {
return nil, nosqldb.ErrNotFound
},
},
err: errors.New("certificate certID not found"),
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
@ -182,7 +181,7 @@ func TestDB_GetCertificate(t *testing.T) {
Type: "Public Key",
Bytes: leaf.Raw,
}),
Created: clock.Now(),
CreatedAt: clock.Now(),
}
b, err := json.Marshal(cert)
assert.FatalError(t, err)
@ -215,7 +214,7 @@ func TestDB_GetCertificate(t *testing.T) {
Type: "CERTIFICATE",
Bytes: root.Raw,
})...),
Created: clock.Now(),
CreatedAt: clock.Now(),
}
b, err := json.Marshal(cert)
assert.FatalError(t, err)
@ -232,8 +231,19 @@ func TestDB_GetCertificate(t *testing.T) {
db := DB{db: tc.db}
cert, err := db.GetCertificate(context.Background(), certID)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else {
if assert.Nil(t, tc.err) {

@ -10,18 +10,17 @@ import (
"github.com/smallstep/nosql"
)
// dbChallenge is the base Challenge type that others build from.
type dbChallenge struct {
ID string `json:"id"`
AccountID string `json:"accountID"`
AuthzID string `json:"authzID"`
Type string `json:"type"`
Status acme.Status `json:"status"`
Token string `json:"token"`
Value string `json:"value"`
Validated string `json:"validated"`
Created time.Time `json:"created"`
Error *acme.Error `json:"error"`
ID string `json:"id"`
AccountID string `json:"accountID"`
AuthzID string `json:"authzID"`
Type string `json:"type"`
Status acme.Status `json:"status"`
Token string `json:"token"`
Value string `json:"value"`
ValidatedAt string `json:"validatedAt"`
CreatedAt time.Time `json:"createdAt"`
Error *acme.Error `json:"error"`
}
func (dbc *dbChallenge) clone() *dbChallenge {
@ -32,9 +31,9 @@ func (dbc *dbChallenge) clone() *dbChallenge {
func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
data, err := db.db.Get(challengeTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, errors.Wrapf(err, "challenge %s not found", id)
return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading challenge %s", id)
return nil, errors.Wrapf(err, "error loading acme challenge %s", id)
}
dbch := new(dbChallenge)
@ -60,7 +59,7 @@ func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
Value: ch.Value,
Status: acme.StatusPending,
Token: ch.Token,
Created: clock.Now(),
CreatedAt: clock.Now(),
Type: ch.Type,
}
@ -76,22 +75,21 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall
}
ch := &acme.Challenge{
Type: dbch.Type,
Status: dbch.Status,
Token: dbch.Token,
ID: dbch.ID,
AuthzID: dbch.AuthzID,
Error: dbch.Error,
Validated: dbch.Validated,
ID: dbch.ID,
AccountID: dbch.AccountID,
AuthzID: dbch.AuthzID,
Type: dbch.Type,
Value: dbch.Value,
Status: dbch.Status,
Token: dbch.Token,
Error: dbch.Error,
ValidatedAt: dbch.ValidatedAt,
}
return ch, nil
}
// UpdateChallenge updates an ACME challenge type in the database.
func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
if len(ch.ID) == 0 {
return errors.New("id cannot be empty")
}
old, err := db.getDBChallenge(ctx, ch.ID)
if err != nil {
return err
@ -99,10 +97,10 @@ func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
nu := old.clone()
// These should be the only values chaning in an Update request.
// These should be the only values changing in an Update request.
nu.Status = ch.Status
nu.Error = ch.Error
nu.Validated = ch.Validated
nu.ValidatedAt = ch.ValidatedAt
return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
}

@ -0,0 +1,477 @@
package nosql
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
)
func TestDB_getDBChallenge(t *testing.T) {
chID := "chID"
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
dbc *dbChallenge
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, nosqldb.ErrNotFound
},
},
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, errors.New("force")
},
},
err: errors.New("error loading acme challenge chID: force"),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return []byte("foo"), nil
},
},
err: errors.New("error unmarshaling dbChallenge"),
}
},
"ok": func(t *testing.T) test {
dbc := &dbChallenge{
ID: chID,
AccountID: "accountID",
AuthzID: "authzID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
CreatedAt: clock.Now(),
ValidatedAt: "foobar",
Error: acme.NewErrorISE("force"),
}
b, err := json.Marshal(dbc)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
},
dbc: dbc,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err, tc.acmeErr.Err)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID)
assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
}
})
}
}
func TestDB_CreateChallenge(t *testing.T) {
type test struct {
db nosql.DB
ch *acme.Challenge
err error
_id *string
}
var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
ch := &acme.Challenge{
AccountID: "accountID",
AuthzID: "authzID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
}
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), ch.ID)
assert.Equals(t, old, nil)
dbc := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.AccountID, ch.AccountID)
assert.Equals(t, dbc.AuthzID, ch.AuthzID)
assert.Equals(t, dbc.Type, ch.Type)
assert.Equals(t, dbc.Status, ch.Status)
assert.Equals(t, dbc.Token, ch.Token)
assert.Equals(t, dbc.Value, ch.Value)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, false, errors.New("force")
},
},
ch: ch,
err: errors.New("error saving acme challenge: force"),
}
},
"ok": func(t *testing.T) test {
var (
id string
idPtr = &id
ch = &acme.Challenge{
AccountID: "accountID",
AuthzID: "authzID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
}
)
return test{
ch: ch,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
*idPtr = string(key)
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), ch.ID)
assert.Equals(t, old, nil)
dbc := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.AccountID, ch.AccountID)
assert.Equals(t, dbc.AuthzID, ch.AuthzID)
assert.Equals(t, dbc.Type, ch.Type)
assert.Equals(t, dbc.Status, ch.Status)
assert.Equals(t, dbc.Token, ch.Token)
assert.Equals(t, dbc.Value, ch.Value)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, true, nil
},
},
_id: idPtr,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.ch.ID, *tc._id)
}
}
})
}
}
func TestDB_GetChallenge(t *testing.T) {
chID := "chID"
azID := "azID"
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
dbc *dbChallenge
}
var tests = map[string]func(t *testing.T) test{
"fail/db.Get-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, errors.New("force")
},
},
err: errors.New("error loading acme challenge chID: force"),
}
},
"fail/forward-acme-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, nosqldb.ErrNotFound
},
},
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
}
},
"ok": func(t *testing.T) test {
dbc := &dbChallenge{
ID: chID,
AccountID: "accountID",
AuthzID: azID,
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
CreatedAt: clock.Now(),
ValidatedAt: "foobar",
Error: acme.NewErrorISE("force"),
}
b, err := json.Marshal(dbc)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
},
dbc: dbc,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID)
assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
}
})
}
}
func TestDB_UpdateChallenge(t *testing.T) {
chID := "chID"
dbc := &dbChallenge{
ID: chID,
AccountID: "accountID",
AuthzID: "azID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
CreatedAt: clock.Now(),
}
b, err := json.Marshal(dbc)
assert.FatalError(t, err)
type test struct {
db nosql.DB
ch *acme.Challenge
err error
}
var tests = map[string]func(t *testing.T) test{
"fail/db.Get-error": func(t *testing.T) test {
return test{
ch: &acme.Challenge{
ID: chID,
},
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, errors.New("force")
},
},
err: errors.New("error loading acme challenge chID: force"),
}
},
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
updCh := &acme.Challenge{
ID: chID,
Status: acme.StatusValid,
ValidatedAt: "foobar",
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
}
return test{
ch: updCh,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, old, b)
dbOld := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(old, dbOld))
assert.Equals(t, dbc, dbOld)
dbNew := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbNew))
assert.Equals(t, dbNew.ID, dbc.ID)
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
assert.Equals(t, dbNew.AuthzID, dbc.AuthzID)
assert.Equals(t, dbNew.Type, dbc.Type)
assert.Equals(t, dbNew.Status, updCh.Status)
assert.Equals(t, dbNew.Token, dbc.Token)
assert.Equals(t, dbNew.Value, dbc.Value)
assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
return nil, false, errors.New("force")
},
},
err: errors.New("error saving acme challenge: force"),
}
},
"ok": func(t *testing.T) test {
updCh := &acme.Challenge{
ID: dbc.ID,
AccountID: dbc.AccountID,
AuthzID: dbc.AuthzID,
Type: dbc.Type,
Token: dbc.Token,
Value: dbc.Value,
Status: acme.StatusValid,
ValidatedAt: "foobar",
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
}
return test{
ch: updCh,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, old, b)
dbOld := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(old, dbOld))
assert.Equals(t, dbc, dbOld)
dbNew := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbNew))
assert.Equals(t, dbNew.ID, dbc.ID)
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
assert.Equals(t, dbNew.AuthzID, dbc.AuthzID)
assert.Equals(t, dbNew.Type, dbc.Type)
assert.Equals(t, dbNew.Token, dbc.Token)
assert.Equals(t, dbNew.Value, dbc.Value)
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
assert.Equals(t, dbNew.Status, acme.StatusValid)
assert.Equals(t, dbNew.ValidatedAt, "foobar")
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
return nu, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.ch.ID, dbc.ID)
assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
assert.Equals(t, tc.ch.AuthzID, dbc.AuthzID)
assert.Equals(t, tc.ch.Type, dbc.Type)
assert.Equals(t, tc.ch.Token, dbc.Token)
assert.Equals(t, tc.ch.Value, dbc.Value)
assert.Equals(t, tc.ch.ValidatedAt, "foobar")
assert.Equals(t, tc.ch.Status, acme.StatusValid)
assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
}
}
})
}
}

@ -8,14 +8,19 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
"github.com/smallstep/nosql"
)
// dbNonce contains nonce metadata used in the ACME protocol.
type dbNonce struct {
ID string
Created time.Time
ID string
CreatedAt time.Time
DeletedAt time.Time
}
func (dbn *dbNonce) clone() *dbNonce {
u := *dbn
return &u
}
// CreateNonce creates, stores, and returns an ACME replay-nonce.
@ -28,14 +33,10 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
n := &dbNonce{
ID: id,
Created: clock.Now(),
}
b, err := json.Marshal(n)
if err != nil {
return "", errors.Wrap(err, "error marshaling nonce")
ID: id,
CreatedAt: clock.Now(),
}
if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil {
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
return "", err
}
return acme.Nonce(id), nil
@ -44,27 +45,24 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
// DeleteNonce verifies that the nonce is valid (by checking if it exists),
// and if so, consumes the nonce resource by deleting it from the database.
func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
err := db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: nonceTable,
Key: []byte(nonce),
Cmd: database.Get,
},
{
Bucket: nonceTable,
Key: []byte(nonce),
Cmd: database.Delete,
},
},
})
id := string(nonce)
b, err := db.db.Get(nonceTable, []byte(nonce))
if nosql.IsErrNotFound(err) {
return errors.Wrapf(err, "nonce %s not found", id)
} else if err != nil {
return errors.Wrapf(err, "error loading nonce %s", id)
}
switch {
case nosqlDB.IsErrNotFound(err):
return errors.New("not found")
case err != nil:
return errors.Wrapf(err, "error deleting nonce %s", nonce)
default:
return nil
dbn := new(dbNonce)
if err := json.Unmarshal(b, dbn); err != nil {
return errors.Wrapf(err, "error unmarshaling nonce %s", string(nonce))
}
if !dbn.DeletedAt.IsZero() {
return acme.NewError(acme.ErrorBadNonceType, "nonce %s already deleted", id)
}
nu := dbn.clone()
nu.DeletedAt = clock.Now()
return db.save(ctx, id, nu, dbn, "nonce", nonceTable)
}

@ -0,0 +1,209 @@
package nosql
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
)
func TestDB_CreateNonce(t *testing.T) {
type test struct {
db nosql.DB
nonce *acme.Nonce
err error
_id *string
}
var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
dbn := new(dbNonce)
assert.FatalError(t, json.Unmarshal(nu, dbn))
assert.Equals(t, dbn.ID, string(key))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
return nil, false, errors.New("force")
},
},
err: errors.New("error saving acme nonce: force"),
}
},
"ok": func(t *testing.T) test {
var (
id string
idPtr = &id
)
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
*idPtr = string(key)
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
dbn := new(dbNonce)
assert.FatalError(t, json.Unmarshal(nu, dbn))
assert.Equals(t, dbn.ID, string(key))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
return nil, true, nil
},
},
_id: idPtr,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if n, err := db.CreateNonce(context.Background()); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, string(n), *tc._id)
}
}
})
}
}
func TestDB_DeleteNonce(t *testing.T) {
nonceID := "nonceID"
type test struct {
db nosql.DB
err error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
return nil, nosqldb.ErrNotFound
},
},
err: errors.New("nonce nonceID not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
return nil, errors.Errorf("force")
},
},
err: errors.New("error loading nonce nonceID: force"),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
a := []string{"foo", "bar", "baz"}
b, err := json.Marshal(a)
assert.FatalError(t, err)
return b, nil
},
},
err: errors.New("error unmarshaling nonce nonceID"),
}
},
"fail/already-used": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
nonce := dbNonce{
ID: nonceID,
CreatedAt: clock.Now().Add(-5 * time.Minute),
DeletedAt: clock.Now(),
}
b, err := json.Marshal(nonce)
assert.FatalError(t, err)
return b, nil
},
},
err: acme.NewError(acme.ErrorBadNonceType, "nonce already deleted"),
}
},
"ok": func(t *testing.T) test {
nonce := dbNonce{
ID: nonceID,
CreatedAt: clock.Now().Add(-5 * time.Minute),
}
b, err := json.Marshal(nonce)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, b)
dbo := new(dbNonce)
assert.FatalError(t, json.Unmarshal(old, dbo))
assert.Equals(t, dbo.ID, string(key))
assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbo.CreatedAt))
assert.True(t, clock.Now().Add(-4*time.Minute).After(dbo.CreatedAt))
assert.True(t, dbo.DeletedAt.IsZero())
dbn := new(dbNonce)
assert.FatalError(t, json.Unmarshal(nu, dbn))
assert.Equals(t, dbn.ID, string(key))
assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbn.CreatedAt))
assert.True(t, clock.Now().Add(-4*time.Minute).After(dbn.CreatedAt))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.DeletedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbn.DeletedAt))
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
Loading…
Cancel
Save