From 088432150d0ab53aeb436f9a37a08c316bff90e7 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Feb 2021 10:23:11 -0800 Subject: [PATCH 01/47] Beginnings of acmeDB interface --- acme/account.go | 197 ---------------------------------------------- acme/authority.go | 61 +++++--------- acme/nonce.go | 73 ----------------- 3 files changed, 20 insertions(+), 311 deletions(-) delete mode 100644 acme/account.go delete mode 100644 acme/nonce.go diff --git a/acme/account.go b/acme/account.go deleted file mode 100644 index 1c5870d5..00000000 --- a/acme/account.go +++ /dev/null @@ -1,197 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" -) - -// Account is a subset of the internal account type containing only those -// attributes required for responses in the ACME protocol. -type Account struct { - Contact []string `json:"contact,omitempty"` - Status string `json:"status"` - Orders string `json:"orders"` - ID string `json:"-"` - Key *jose.JSONWebKey `json:"-"` -} - -// ToLog enables response logging. -func (a *Account) ToLog() (interface{}, error) { - b, err := json.Marshal(a) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging")) - } - return string(b), nil -} - -// GetID returns the account ID. -func (a *Account) GetID() string { - return a.ID -} - -// GetKey returns the JWK associated with the account. -func (a *Account) GetKey() *jose.JSONWebKey { - return a.Key -} - -// IsValid returns true if the Account is valid. -func (a *Account) IsValid() bool { - return a.Status == StatusValid -} - -// AccountOptions are the options needed to create a new ACME account. -type AccountOptions struct { - Key *jose.JSONWebKey - Contact []string -} - -// account represents an ACME account. -type account struct { - ID string `json:"id"` - Created time.Time `json:"created"` - Deactivated time.Time `json:"deactivated"` - Key *jose.JSONWebKey `json:"key"` - Contact []string `json:"contact,omitempty"` - Status string `json:"status"` -} - -// newAccount returns a new acme account type. -func newAccount(db nosql.DB, ops AccountOptions) (*account, error) { - id, err := randID() - if err != nil { - return nil, err - } - - a := &account{ - ID: id, - Key: ops.Key, - Contact: ops.Contact, - Status: "valid", - Created: clock.Now(), - } - return a, a.saveNew(db) -} - -// toACME converts the internal Account type into the public acmeAccount -// type for presentation in the ACME protocol. -func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) { - return &Account{ - Status: a.Status, - Contact: a.Contact, - Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), - Key: a.Key, - ID: a.ID, - }, nil -} - -// save writes the Account to the DB. -// If the account is new then the necessary indices will be created. -// Else, the account in the DB will be updated. -func (a *account) saveNew(db nosql.DB) error { - kid, err := keyToID(a.Key) - if err != nil { - return err - } - kidB := []byte(kid) - - // Set the jwkID -> acme account ID index - _, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID)) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index")) - case !swapped: - return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) - default: - if err = a.save(db, nil); err != nil { - db.Del(accountByKeyIDTable, kidB) - return err - } - return nil - } -} - -func (a *account) save(db nosql.DB, old *account) error { - var ( - err error - oldB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) - } - } - - b, err := json.Marshal(*a) - if err != nil { - return errors.Wrap(err, "error marshaling new account object") - } - // Set the Account - _, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error storing account")) - case !swapped: - return ServerInternalErr(errors.New("error storing account; " + - "value has changed since last read")) - default: - return nil - } -} - -// update updates the acme account object stored in the database if, -// and only if, the account has not changed since the last read. -func (a *account) update(db nosql.DB, contact []string) (*account, error) { - b := *a - b.Contact = contact - if err := b.save(db, a); err != nil { - return nil, err - } - return &b, nil -} - -// deactivate deactivates the acme account. -func (a *account) deactivate(db nosql.DB) (*account, error) { - b := *a - b.Status = StatusDeactivated - b.Deactivated = clock.Now() - if err := b.save(db, a); err != nil { - return nil, err - } - return &b, nil -} - -// getAccountByID retrieves the account with the given ID. -func getAccountByID(db nosql.DB, id string) (*account, error) { - ab, err := db.Get(accountTable, []byte(id)) - if err != nil { - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) - } - - a := new(account) - if err = json.Unmarshal(ab, a); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) - } - return a, nil -} - -// getAccountByKeyID retrieves Id associated with the given Kid. -func getAccountByKeyID(db nosql.DB, kid string) (*account, error) { - id, err := db.Get(accountByKeyIDTable, []byte(kid)) - if err != nil { - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) - } - return getAccountByID(db, string(id)) -} diff --git a/acme/authority.go b/acme/authority.go index 0f5f2c9f..6de3a5e1 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -92,7 +92,7 @@ func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Aut }) } -// New returns a new Autohrity that implements the ACME interface. +// New returns a new Authority that implements the ACME interface. func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { if _, ok := ops.DB.(*database.SimpleDB); !ok { // If it's not a SimpleDB then go ahead and bootstrap the DB with the @@ -140,59 +140,41 @@ func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error } // NewNonce generates, stores, and returns a new ACME nonce. -func (a *Authority) NewNonce() (string, error) { - n, err := newNonce(a.db) - if err != nil { - return "", err - } - return n.ID, nil +func (a *Authority) NewNonce(ctx context.Context) (string, error) { + return a.db.CreateNonce(ctx) } // UseNonce consumes the given nonce if it is valid, returns error otherwise. -func (a *Authority) UseNonce(nonce string) error { - return useNonce(a.db, nonce) +func (a *Authority) UseNonce(ctx context.Context, nonce string) error { + return a.db.DeleteNonce(ctx, nonce) } // NewAccount creates, stores, and returns a new ACME account. func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) { - acc, err := newAccount(a.db, ao) - if err != nil { - return nil, err + a := NewAccount(ao) + if err := a.db.CreateAccount(ctx, a); err != nil { + return ServerInternalErr(err) } - return acc.toACME(ctx, a.db, a.dir) + return a, nil } // UpdateAccount updates an ACME account. -func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) { - acc, err := getAccountByID(a.db, id) +func (a *Authority) UpdateAccount(ctx context.Context, auo AccountUpdateOptions) (*Account, error) { + acc, err := a.db.GetAccount(ctx, auo.ID) if err != nil { - return nil, ServerInternalErr(err) + return ServerInternalErr(err) } - if acc, err = acc.update(a.db, contact); err != nil { - return nil, err + acc.Contact = auo.Contact + acc.Status = auo.Status + if err = a.db.UpdateAccount(ctx, acc); err != nil { + return ServerInternalErr(err) } - return acc.toACME(ctx, a.db, a.dir) + return acc, nil } // GetAccount returns an ACME account. func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { - acc, err := getAccountByID(a.db, id) - if err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) -} - -// DeactivateAccount deactivates an ACME account. -func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) { - acc, err := getAccountByID(a.db, id) - if err != nil { - return nil, err - } - if acc, err = acc.deactivate(a.db); err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) + return a.db.GetAccount(ctx, id) } func keyToID(jwk *jose.JSONWebKey) (string, error) { @@ -209,11 +191,8 @@ func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) ( if err != nil { return nil, err } - acc, err := getAccountByKeyID(a.db, kid) - if err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) + acc, err := a.db.GetAccountByKeyID(ctx, kid) + return acc, err } // GetOrder returns an ACME order. diff --git a/acme/nonce.go b/acme/nonce.go deleted file mode 100644 index db680f08..00000000 --- a/acme/nonce.go +++ /dev/null @@ -1,73 +0,0 @@ -package acme - -import ( - "encoding/base64" - "encoding/json" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - -// nonce contains nonce metadata used in the ACME protocol. -type nonce struct { - ID string - Created time.Time -} - -// newNonce creates, stores, and returns an ACME replay-nonce. -func newNonce(db nosql.DB) (*nonce, error) { - _id, err := randID() - if err != nil { - return nil, err - } - - id := base64.RawURLEncoding.EncodeToString([]byte(_id)) - n := &nonce{ - ID: id, - Created: clock.Now(), - } - b, err := json.Marshal(n) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) - } - _, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b) - switch { - case err != nil: - return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce")) - case !swapped: - return nil, ServerInternalErr(errors.New("error storing nonce; " + - "value has changed since last read")) - default: - return n, nil - } -} - -// useNonce verifies that the nonce is valid (by checking if it exists), -// and if so, consumes the nonce resource by deleting it from the database. -func useNonce(db nosql.DB, nonce string) error { - err := db.Update(&database.Tx{ - Operations: []*database.TxEntry{ - { - Bucket: nonceTable, - Key: []byte(nonce), - Cmd: database.Get, - }, - { - Bucket: nonceTable, - Key: []byte(nonce), - Cmd: database.Delete, - }, - }, - }) - - switch { - case nosql.IsErrNotFound(err): - return BadNonceErr(nil) - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce)) - default: - return nil - } -} From 34859551ef9086b19b5be36725d723a7a943ef1d Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Feb 2021 10:24:24 -0800 Subject: [PATCH 02/47] Add new directory structure --- acme/db/db.go | 31 ++++++++ acme/db/nosql/account.go | 138 +++++++++++++++++++++++++++++++++++ acme/db/nosql/authz.go | 0 acme/db/nosql/certificate.go | 0 acme/db/nosql/challenge.go | 0 acme/db/nosql/nonce.go | 74 +++++++++++++++++++ acme/db/nosql/nosql.go | 10 +++ acme/db/nosql/order.go | 0 acme/types/account.go | 66 +++++++++++++++++ acme/types/authz.go | 32 ++++++++ acme/types/nonce.go | 3 + 11 files changed, 354 insertions(+) create mode 100644 acme/db/db.go create mode 100644 acme/db/nosql/account.go create mode 100644 acme/db/nosql/authz.go create mode 100644 acme/db/nosql/certificate.go create mode 100644 acme/db/nosql/challenge.go create mode 100644 acme/db/nosql/nonce.go create mode 100644 acme/db/nosql/nosql.go create mode 100644 acme/db/nosql/order.go create mode 100644 acme/types/account.go create mode 100644 acme/types/authz.go create mode 100644 acme/types/nonce.go diff --git a/acme/db/db.go b/acme/db/db.go new file mode 100644 index 00000000..2882fcf4 --- /dev/null +++ b/acme/db/db.go @@ -0,0 +1,31 @@ +package acme + +import "context" + +// DB is the DB interface expected by the step-ca ACME API. +type DB interface { + CreateAccount(ctx context.Context, acc *Account) (*Account, error) + GetAccount(ctx context.Context, id string) (*Account, error) + GetAccountByKeyID(ctx context.Context) (*Account, error) + UpdateAccount(ctx context.Context, acc *Account) error + + CreateNonce(ctx context.Context) (Nonce, error) + DeleteNonce(ctx context.Context, nonce Nonce) error + + CreateAuthorization(ctx context.Context, authz *Authorization) error + GetAuthorization(ctx context.Context, id string) (*Authorization, error) + UpdateAuthorization(ctx context.Context, authz *Authorization) error + + CreateCertificate(ctx context.Context, cert *Certificate) error + GetCertificate(ctx context.Context, id string) (*Certificate, error) + + CreateChallenge(ctx context.Context, ch *Challenge) error + GetChallenge(ctx context.Context, id string) (*Challenge, error) + UpdateChallenge(ctx context.Context, ch *Challenge) error + + CreateOrder(ctx context.Context, o *Order) error + DeleteOrder(ctx context.Context, id string) error + GetOrder(ctx context.Context, id string) (*Order, error) + GetOrdersByAccountID(ctx context.Context, accountID string) error + UpdateOrder(ctx context.Context, o *Order) error +} diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go new file mode 100644 index 00000000..dcbfd2f5 --- /dev/null +++ b/acme/db/nosql/account.go @@ -0,0 +1,138 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + nosqlDB "github.com/smallstep/nosql" + "go.step.sm/crypto/jose" +) + +// dbAccount represents an ACME account. +type dbAccount struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Deactivated time.Time `json:"deactivated"` + Key *jose.JSONWebKey `json:"key"` + Contact []string `json:"contact,omitempty"` + Status string `json:"status"` +} + +func (db *DB) saveAccount(nu *dbAccount, old *dbAccount) error { + var ( + err error + oldB []byte + ) + if old == nil { + oldB = nil + } else { + if oldB, err = json.Marshal(old); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) + } + } + + b, err := json.Marshal(*nu) + if err != nil { + return errors.Wrap(err, "error marshaling new account object") + } + // Set the Account + _, swapped, err := db.CmpAndSwap(accountTable, []byte(nu.ID), oldB, b) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error storing account")) + case !swapped: + return ServerInternalErr(errors.New("error storing account; " + + "value has changed since last read")) + default: + return nil + } +} + +// CreateAccount imlements the AcmeDB.CreateAccount interface. +func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { + id, err := randID() + if err != nil { + return nil, err + } + + dba := &dbAccount{ + ID: id, + Key: acc.Key, + Contact: acc.Contact, + Status: acc.Valid, + Created: clock.Now(), + } + + kid, err := keyToID(dba.Key) + if err != nil { + return err + } + kidB := []byte(kid) + + // Set the jwkID -> acme account ID index + _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID)) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index")) + case !swapped: + return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) + default: + if err = db.saveAccount(dba, nil); err != nil { + db.db.Del(accountByKeyIDTable, kidB) + return err + } + return nil + } +} + +// UpdateAccount imlements the AcmeDB.UpdateAccount interface. +func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error { + kid, err := keyToID(dba.Key) + if err != nil { + return err + } + + dba, err := db.db.getAccountByKeyID(ctx, kid) + + newdba := *dba + newdba.Contact = acc.contact + newdba.Status = acc.Status + + // If the status has changed to 'deactivated', then set deactivatedAt timestamp. + if acc.Status == types.StatusDeactivated && dba.Status != types.Status.Deactivated { + newdba.Deactivated = clock.Now() + } + + return db.saveAccount(newdba, dba) +} + +// getAccountByID retrieves the account with the given ID. +func (db *DB) getAccountByID(ctx context.Context, id string) (*dbAccount, error) { + ab, err := db.db.Get(accountTable, []byte(id)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) + } + + a := new(account) + if err = json.Unmarshal(ab, a); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) + } + return a, nil +} + +// getAccountByKeyID retrieves Id associated with the given Kid. +func (db *DB) getAccountByKeyID(ctx context.Context, kid string) (*dbAccount, error) { + id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) + } + return getAccountByID(db, string(id)) +} diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go new file mode 100644 index 00000000..e69de29b diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go new file mode 100644 index 00000000..e69de29b diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go new file mode 100644 index 00000000..e69de29b diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go new file mode 100644 index 00000000..3459f212 --- /dev/null +++ b/acme/db/nosql/nonce.go @@ -0,0 +1,74 @@ +package nosql + +import ( + "encoding/base64" + "encoding/json" + "time" + + "github.com/pkg/errors" + nosqlDB "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +// dbNonce contains nonce metadata used in the ACME protocol. +type dbNonce struct { + ID string + Created time.Time +} + +// CreateNonce creates, stores, and returns an ACME replay-nonce. +// Implements the acme.DB interface. +func (db *DB) CreateNonce() (Nonce, error) { + _id, err := randID() + if err != nil { + return nil, err + } + + id := base64.RawURLEncoding.EncodeToString([]byte(_id)) + n := &dbNonce{ + ID: id, + Created: clock.Now(), + } + b, err := json.Marshal(n) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) + } + _, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b) + switch { + case err != nil: + return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce")) + case !swapped: + return nil, ServerInternalErr(errors.New("error storing nonce; " + + "value has changed since last read")) + default: + return Nonce(id), nil + } +} + +// DeleteNonce verifies that the nonce is valid (by checking if it exists), +// and if so, consumes the nonce resource by deleting it from the database. +func (db *DB) DeleteNonce(nonce string) error { + err := db.db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Get, + }, + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Delete, + }, + }, + }) + + switch { + case nosqlDB.IsErrNotFound(err): + return BadNonceErr(nil) + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce)) + default: + return nil + } +} diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go new file mode 100644 index 00000000..8bfd1a66 --- /dev/null +++ b/acme/db/nosql/nosql.go @@ -0,0 +1,10 @@ +package nosql + +import ( + nosqlDB "github.com/smallstep/nosql" +) + +// DB is a struct that implements the AcmeDB interface. +type DB struct { + db nosqlDB.DB +} diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go new file mode 100644 index 00000000..e69de29b diff --git a/acme/types/account.go b/acme/types/account.go new file mode 100644 index 00000000..ea40a646 --- /dev/null +++ b/acme/types/account.go @@ -0,0 +1,66 @@ +package acme + +import ( + "encoding/json" + + "github.com/pkg/errors" + "go.step.sm/crypto/jose" +) + +// Account is a subset of the internal account type containing only those +// attributes required for responses in the ACME protocol. +type Account struct { + Contact []string `json:"contact,omitempty"` + Status string `json:"status"` + Orders string `json:"orders"` + ID string `json:"-"` + Key *jose.JSONWebKey `json:"-"` +} + +// ToLog enables response logging. +func (a *Account) ToLog() (interface{}, error) { + b, err := json.Marshal(a) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging")) + } + return string(b), nil +} + +// GetID returns the account ID. +func (a *Account) GetID() string { + return a.ID +} + +// GetKey returns the JWK associated with the account. +func (a *Account) GetKey() *jose.JSONWebKey { + return a.Key +} + +// IsValid returns true if the Account is valid. +func (a *Account) IsValid() bool { + return a.Status == StatusValid +} + +// AccountOptions are the options needed to create a new ACME account. +type AccountOptions struct { + Key *jose.JSONWebKey + Contact []string +} + +// AccountUpdateOptions are the options needed to update an existing ACME account. +type AccountUpdateOptions struct { + Contact []string + Status types.Status +} + +// toACME converts the internal Account type into the public acmeAccount +// type for presentation in the ACME protocol. +//func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) { +// return &Account{ +// Status: a.Status, +// Contact: a.Contact, +// Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), +// Key: a.Key, +// ID: a.ID, +// }, nil +//} diff --git a/acme/types/authz.go b/acme/types/authz.go new file mode 100644 index 00000000..3e3a5aa7 --- /dev/null +++ b/acme/types/authz.go @@ -0,0 +1,32 @@ +package types + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +// Authz is a subset of the Authz type containing only those attributes +// required for responses in the ACME protocol. +type Authz struct { + Identifier Identifier `json:"identifier"` + Status string `json:"status"` + Expires string `json:"expires"` + Challenges []*Challenge `json:"challenges"` + Wildcard bool `json:"wildcard"` + ID string `json:"-"` +} + +// ToLog enables response logging. +func (a *Authz) ToLog() (interface{}, error) { + b, err := json.Marshal(a) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) + } + return string(b), nil +} + +// GetID returns the Authz ID. +func (a *Authz) GetID() string { + return a.ID +} diff --git a/acme/types/nonce.go b/acme/types/nonce.go new file mode 100644 index 00000000..4234e818 --- /dev/null +++ b/acme/types/nonce.go @@ -0,0 +1,3 @@ +package acme + +type Nonce string From 31ad7f2e9b30dd41cb2e24fa28212515b54d5a47 Mon Sep 17 00:00:00 2001 From: max furman Date: Fri, 26 Feb 2021 10:12:30 -0800 Subject: [PATCH 03/47] [acme] Continued work on acme db interface (wip) --- acme/authority.go | 10 + acme/authz.go | 347 --------------------- acme/challenge.go | 620 ------------------------------------- acme/db/db.go | 34 +- acme/db/nosql/account.go | 89 ++++-- acme/db/nosql/authz.go | 216 +++++++++++++ acme/db/nosql/challenge.go | 528 +++++++++++++++++++++++++++++++ acme/db/nosql/order.go | 469 ++++++++++++++++++++++++++++ acme/types/authz.go | 11 +- acme/types/challenge.go | 30 ++ acme/types/order.go | 30 ++ acme/types/status.go | 20 ++ 12 files changed, 1382 insertions(+), 1022 deletions(-) delete mode 100644 acme/authz.go delete mode 100644 acme/challenge.go create mode 100644 acme/types/challenge.go create mode 100644 acme/types/order.go create mode 100644 acme/types/status.go diff --git a/acme/authority.go b/acme/authority.go index 6de3a5e1..c9190811 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -319,3 +319,13 @@ func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) { } return cert.toACME(a.db, a.dir) } + +type httpGetter func(string) (*http.Response, error) +type lookupTxt func(string) ([]string, error) +type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) + +type validateOptions struct { + httpGet httpGetter + lookupTxt lookupTxt + tlsDial tlsDialer +} diff --git a/acme/authz.go b/acme/authz.go deleted file mode 100644 index 8c45bce0..00000000 --- a/acme/authz.go +++ /dev/null @@ -1,347 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "strings" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/nosql" -) - -var defaultExpiryDuration = time.Hour * 24 - -// Authz is a subset of the Authz type containing only those attributes -// required for responses in the ACME protocol. -type Authz struct { - Identifier Identifier `json:"identifier"` - Status string `json:"status"` - Expires string `json:"expires"` - Challenges []*Challenge `json:"challenges"` - Wildcard bool `json:"wildcard"` - ID string `json:"-"` -} - -// ToLog enables response logging. -func (a *Authz) ToLog() (interface{}, error) { - b, err := json.Marshal(a) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) - } - return string(b), nil -} - -// GetID returns the Authz ID. -func (a *Authz) GetID() string { - return a.ID -} - -// authz is the interface that the various authz types must implement. -type authz interface { - save(nosql.DB, authz) error - clone() *baseAuthz - getID() string - getAccountID() string - getType() string - getIdentifier() Identifier - getStatus() string - getExpiry() time.Time - getWildcard() bool - getChallenges() []string - getCreated() time.Time - updateStatus(db nosql.DB) (authz, error) - toACME(context.Context, nosql.DB, *directory) (*Authz, error) -} - -// baseAuthz is the base authz type that others build from. -type baseAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier Identifier `json:"identifier"` - Status string `json:"status"` - Expires time.Time `json:"expires"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` - Error *Error `json:"error"` -} - -func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) { - id, err := randID() - if err != nil { - return nil, err - } - - now := clock.Now() - ba := &baseAuthz{ - ID: id, - AccountID: accID, - Status: StatusPending, - Created: now, - Expires: now.Add(defaultExpiryDuration), - Identifier: identifier, - } - - if strings.HasPrefix(identifier.Value, "*.") { - ba.Wildcard = true - ba.Identifier = Identifier{ - Value: strings.TrimPrefix(identifier.Value, "*."), - Type: identifier.Type, - } - } - - return ba, nil -} - -// getID returns the ID of the authz. -func (ba *baseAuthz) getID() string { - return ba.ID -} - -// getAccountID returns the Account ID that created the authz. -func (ba *baseAuthz) getAccountID() string { - return ba.AccountID -} - -// getType returns the type of the authz. -func (ba *baseAuthz) getType() string { - return ba.Identifier.Type -} - -// getIdentifier returns the identifier for the authz. -func (ba *baseAuthz) getIdentifier() Identifier { - return ba.Identifier -} - -// getStatus returns the status of the authz. -func (ba *baseAuthz) getStatus() string { - return ba.Status -} - -// getWildcard returns true if the authz identifier has a '*', false otherwise. -func (ba *baseAuthz) getWildcard() bool { - return ba.Wildcard -} - -// getChallenges returns the authz challenge IDs. -func (ba *baseAuthz) getChallenges() []string { - return ba.Challenges -} - -// getExpiry returns the expiration time of the authz. -func (ba *baseAuthz) getExpiry() time.Time { - return ba.Expires -} - -// getCreated returns the created time of the authz. -func (ba *baseAuthz) getCreated() time.Time { - return ba.Created -} - -// toACME converts the internal Authz type into the public acmeAuthz type for -// presentation in the ACME protocol. -func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) { - var chs = make([]*Challenge, len(ba.Challenges)) - for i, chID := range ba.Challenges { - ch, err := getChallenge(db, chID) - if err != nil { - return nil, err - } - chs[i], err = ch.toACME(ctx, db, dir) - if err != nil { - return nil, err - } - } - return &Authz{ - Identifier: ba.Identifier, - Status: ba.getStatus(), - Challenges: chs, - Wildcard: ba.getWildcard(), - Expires: ba.Expires.Format(time.RFC3339), - ID: ba.ID, - }, nil -} - -func (ba *baseAuthz) save(db nosql.DB, old authz) error { - var ( - err error - oldB, newB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old authz")) - } - } - if newB, err = json.Marshal(ba); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new authz")) - } - _, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing authz")) - case !swapped: - return ServerInternalErr(errors.Errorf("error storing authz; " + - "value has changed since last read")) - default: - return nil - } -} - -func (ba *baseAuthz) clone() *baseAuthz { - u := *ba - return &u -} - -func (ba *baseAuthz) parent() authz { - return &dnsAuthz{ba} -} - -// updateStatus attempts to update the status on a baseAuthz and stores the -// updating object if necessary. -func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) { - newAuthz := ba.clone() - - now := time.Now().UTC() - switch ba.Status { - case StatusInvalid: - return ba.parent(), nil - case StatusValid: - return ba.parent(), nil - case StatusPending: - // check expiry - if now.After(ba.Expires) { - newAuthz.Status = StatusInvalid - newAuthz.Error = MalformedErr(errors.New("authz has expired")) - break - } - - var isValid = false - for _, chID := range ba.Challenges { - ch, err := getChallenge(db, chID) - if err != nil { - return ba, err - } - if ch.getStatus() == StatusValid { - isValid = true - break - } - } - - if !isValid { - return ba.parent(), nil - } - newAuthz.Status = StatusValid - newAuthz.Error = nil - default: - return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) - } - - if err := newAuthz.save(db, ba); err != nil { - return ba, err - } - return newAuthz.parent(), nil -} - -// unmarshalAuthz unmarshals an authz type into the correct sub-type. -func unmarshalAuthz(data []byte) (authz, error) { - var getType struct { - Identifier Identifier `json:"identifier"` - } - if err := json.Unmarshal(data, &getType); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type")) - } - - switch getType.Identifier.Type { - case "dns": - var ba baseAuthz - if err := json.Unmarshal(data, &ba); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz")) - } - return &dnsAuthz{&ba}, nil - default: - return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s", - getType.Identifier.Type)) - } -} - -// dnsAuthz represents a dns acme authorization. -type dnsAuthz struct { - *baseAuthz -} - -// newAuthz returns a new acme authorization object based on the identifier -// type. -func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) { - switch identifier.Type { - case "dns": - a, err = newDNSAuthz(db, accID, identifier) - default: - err = MalformedErr(errors.Errorf("unexpected authz type %s", - identifier.Type)) - } - return -} - -// newDNSAuthz returns a new dns acme authorization object. -func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) { - ba, err := newBaseAuthz(accID, identifier) - if err != nil { - return nil, err - } - - ba.Challenges = []string{} - if !ba.Wildcard { - // http and alpn challenges are only permitted if the DNS is not a wildcard dns. - ch1, err := newHTTP01Challenge(db, ChallengeOptions{ - AccountID: accID, - AuthzID: ba.ID, - Identifier: ba.Identifier}) - if err != nil { - return nil, Wrap(err, "error creating http challenge") - } - ba.Challenges = append(ba.Challenges, ch1.getID()) - - ch2, err := newTLSALPN01Challenge(db, ChallengeOptions{ - AccountID: accID, - AuthzID: ba.ID, - Identifier: ba.Identifier, - }) - if err != nil { - return nil, Wrap(err, "error creating alpn challenge") - } - ba.Challenges = append(ba.Challenges, ch2.getID()) - } - ch3, err := newDNS01Challenge(db, ChallengeOptions{ - AccountID: accID, - AuthzID: ba.ID, - Identifier: identifier}) - if err != nil { - return nil, Wrap(err, "error creating dns challenge") - } - ba.Challenges = append(ba.Challenges, ch3.getID()) - - da := &dnsAuthz{ba} - if err := da.save(db, nil); err != nil { - return nil, err - } - - return da, nil -} - -// getAuthz retrieves and unmarshals an ACME authz type from the database. -func getAuthz(db nosql.DB, id string) (authz, error) { - b, err := db.Get(authzTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id)) - } - az, err := unmarshalAuthz(b) - if err != nil { - return nil, err - } - return az, nil -} diff --git a/acme/challenge.go b/acme/challenge.go deleted file mode 100644 index 6d2d13d1..00000000 --- a/acme/challenge.go +++ /dev/null @@ -1,620 +0,0 @@ -package acme - -import ( - "context" - "crypto" - "crypto/sha256" - "crypto/subtle" - "crypto/tls" - "encoding/asn1" - "encoding/base64" - "encoding/hex" - "encoding/json" - "fmt" - "io/ioutil" - "net" - "net/http" - "strings" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" -) - -// Challenge is a subset of the challenge type containing only those attributes -// required for responses in the ACME protocol. -type Challenge struct { - Type string `json:"type"` - Status string `json:"status"` - Token string `json:"token"` - Validated string `json:"validated,omitempty"` - URL string `json:"url"` - Error *AError `json:"error,omitempty"` - ID string `json:"-"` - AuthzID string `json:"-"` -} - -// ToLog enables response logging. -func (c *Challenge) ToLog() (interface{}, error) { - b, err := json.Marshal(c) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) - } - return string(b), nil -} - -// GetID returns the Challenge ID. -func (c *Challenge) GetID() string { - return c.ID -} - -// GetAuthzID returns the parent Authz ID that owns the Challenge. -func (c *Challenge) GetAuthzID() string { - return c.AuthzID -} - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -type validateOptions struct { - httpGet httpGetter - lookupTxt lookupTxt - tlsDial tlsDialer -} - -// challenge is the interface ACME challenege types must implement. -type challenge interface { - save(db nosql.DB, swap challenge) error - validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error) - getType() string - getError() *AError - getValue() string - getStatus() string - getID() string - getAuthzID() string - getToken() string - clone() *baseChallenge - getAccountID() string - getValidated() time.Time - getCreated() time.Time - toACME(context.Context, nosql.DB, *directory) (*Challenge, error) -} - -// ChallengeOptions is the type used to created a new Challenge. -type ChallengeOptions struct { - AccountID string - AuthzID string - Identifier Identifier -} - -// baseChallenge is the base Challenge type that others build from. -type baseChallenge struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - AuthzID string `json:"authzID"` - Type string `json:"type"` - Status string `json:"status"` - Token string `json:"token"` - Value string `json:"value"` - Validated time.Time `json:"validated"` - Created time.Time `json:"created"` - Error *AError `json:"error"` -} - -func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) { - id, err := randID() - if err != nil { - return nil, Wrap(err, "error generating random id for ACME challenge") - } - token, err := randID() - if err != nil { - return nil, Wrap(err, "error generating token for ACME challenge") - } - - return &baseChallenge{ - ID: id, - AccountID: accountID, - AuthzID: authzID, - Status: StatusPending, - Token: token, - Created: clock.Now(), - }, nil -} - -// getID returns the id of the baseChallenge. -func (bc *baseChallenge) getID() string { - return bc.ID -} - -// getAuthzID returns the authz ID of the baseChallenge. -func (bc *baseChallenge) getAuthzID() string { - return bc.AuthzID -} - -// getAccountID returns the account id of the baseChallenge. -func (bc *baseChallenge) getAccountID() string { - return bc.AccountID -} - -// getType returns the type of the baseChallenge. -func (bc *baseChallenge) getType() string { - return bc.Type -} - -// getValue returns the type of the baseChallenge. -func (bc *baseChallenge) getValue() string { - return bc.Value -} - -// getStatus returns the status of the baseChallenge. -func (bc *baseChallenge) getStatus() string { - return bc.Status -} - -// getToken returns the token of the baseChallenge. -func (bc *baseChallenge) getToken() string { - return bc.Token -} - -// getValidated returns the validated time of the baseChallenge. -func (bc *baseChallenge) getValidated() time.Time { - return bc.Validated -} - -// getCreated returns the created time of the baseChallenge. -func (bc *baseChallenge) getCreated() time.Time { - return bc.Created -} - -// getCreated returns the created time of the baseChallenge. -func (bc *baseChallenge) getError() *AError { - return bc.Error -} - -// toACME converts the internal Challenge type into the public acmeChallenge -// type for presentation in the ACME protocol. -func (bc *baseChallenge) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Challenge, error) { - ac := &Challenge{ - Type: bc.getType(), - Status: bc.getStatus(), - Token: bc.getToken(), - URL: dir.getLink(ctx, ChallengeLink, true, bc.getID()), - ID: bc.getID(), - AuthzID: bc.getAuthzID(), - } - if !bc.Validated.IsZero() { - ac.Validated = bc.Validated.Format(time.RFC3339) - } - if bc.Error != nil { - ac.Error = bc.Error - } - return ac, nil -} - -// save writes the challenge to disk. For new challenges 'old' should be nil, -// otherwise 'old' should be a pointer to the acme challenge as it was at the -// start of the request. This method will fail if the value currently found -// in the bucket/row does not match the value of 'old'. -func (bc *baseChallenge) save(db nosql.DB, old challenge) error { - newB, err := json.Marshal(bc) - if err != nil { - return ServerInternalErr(errors.Wrap(err, - "error marshaling new acme challenge")) - } - var oldB []byte - if old == nil { - oldB = nil - } else { - oldB, err = json.Marshal(old) - if err != nil { - return ServerInternalErr(errors.Wrap(err, - "error marshaling old acme challenge")) - } - } - - _, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error saving acme challenge")) - case !swapped: - return ServerInternalErr(errors.New("error saving acme challenge; " + - "acme challenge has changed since last read")) - default: - return nil - } -} - -func (bc *baseChallenge) clone() *baseChallenge { - u := *bc - return &u -} - -func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - return nil, ServerInternalErr(errors.New("unimplemented")) -} - -func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error { - clone := bc.clone() - clone.Error = err.ToACME() - if err := clone.save(db, bc); err != nil { - return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge")) - } - return nil -} - -// unmarshalChallenge unmarshals a challenge type into the correct sub-type. -func unmarshalChallenge(data []byte) (challenge, error) { - var getType struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &getType); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type")) - } - - switch getType.Type { - case "dns-01": - var bc baseChallenge - if err := json.Unmarshal(data, &bc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into dns01Challenge")) - } - return &dns01Challenge{&bc}, nil - case "http-01": - var bc baseChallenge - if err := json.Unmarshal(data, &bc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into http01Challenge")) - } - return &http01Challenge{&bc}, nil - case "tls-alpn-01": - var bc baseChallenge - if err := json.Unmarshal(data, &bc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into tlsALPN01Challenge")) - } - return &tlsALPN01Challenge{&bc}, nil - default: - return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type)) - } -} - -// http01Challenge represents an http-01 acme challenge. -type http01Challenge struct { - *baseChallenge -} - -// newHTTP01Challenge returns a new acme http-01 challenge. -func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { - bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) - if err != nil { - return nil, err - } - bc.Type = "http-01" - bc.Value = ops.Identifier.Value - - hc := &http01Challenge{bc} - if err := hc.save(db, nil); err != nil { - return nil, err - } - return hc, nil -} - -// Validate attempts to validate the challenge. If the challenge has been -// satisfactorily validated, the 'status' and 'validated' attributes are -// updated. -func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - // If already valid or invalid then return without performing validation. - if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid { - return hc, nil - } - url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token) - - resp, err := vo.httpGet(url) - if err != nil { - if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err, - "error doing http GET for url %s", url))); err != nil { - return nil, err - } - return hc, nil - } - if resp.StatusCode >= 400 { - if err = hc.storeError(db, - ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", - url, resp.StatusCode))); err != nil { - return nil, err - } - return hc, nil - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+ - "response body for url %s", url)) - } - keyAuth := strings.Trim(string(body), "\r\n") - - expected, err := KeyAuthorization(hc.Token, jwk) - if err != nil { - return nil, err - } - if keyAuth != expected { - if err = hc.storeError(db, - RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got %s", expected, keyAuth))); err != nil { - return nil, err - } - return hc, nil - } - - // Update and store the challenge. - upd := &http01Challenge{hc.baseChallenge.clone()} - upd.Status = StatusValid - upd.Error = nil - upd.Validated = clock.Now() - - if err := upd.save(db, hc); err != nil { - return nil, err - } - return upd, nil -} - -type tlsALPN01Challenge struct { - *baseChallenge -} - -// newTLSALPN01Challenge returns a new acme tls-alpn-01 challenge. -func newTLSALPN01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { - bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) - if err != nil { - return nil, err - } - bc.Type = "tls-alpn-01" - bc.Value = ops.Identifier.Value - - hc := &tlsALPN01Challenge{bc} - if err := hc.save(db, nil); err != nil { - return nil, err - } - return hc, nil -} - -func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - // If already valid or invalid then return without performing validation. - if tc.getStatus() == StatusValid || tc.getStatus() == StatusInvalid { - return tc, nil - } - - config := &tls.Config{ - NextProtos: []string{"acme-tls/1"}, - ServerName: tc.Value, - InsecureSkipVerify: true, // we expect a self-signed challenge certificate - } - - hostPort := net.JoinHostPort(tc.Value, "443") - - conn, err := vo.tlsDial("tcp", hostPort, config) - if err != nil { - if err = tc.storeError(db, - ConnectionErr(errors.Wrapf(err, "error doing TLS dial for %s", hostPort))); err != nil { - return nil, err - } - return tc, nil - } - defer conn.Close() - - cs := conn.ConnectionState() - certs := cs.PeerCertificates - - if len(certs) == 0 { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("%s challenge for %s resulted in no certificates", - tc.Type, tc.Value))); err != nil { - return nil, err - } - return tc, nil - } - - if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("cannot negotiate ALPN acme-tls/1 protocol for "+ - "tls-alpn-01 challenge"))); err != nil { - return nil, err - } - return tc, nil - } - - leafCert := certs[0] - - if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "leaf certificate must contain a single DNS name, %v", tc.Value))); err != nil { - return nil, err - } - return tc, nil - } - - idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} - idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} - foundIDPeAcmeIdentifierV1Obsolete := false - - keyAuth, err := KeyAuthorization(tc.Token, jwk) - if err != nil { - return nil, err - } - hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) - - for _, ext := range leafCert.Extensions { - if idPeAcmeIdentifier.Equal(ext.Id) { - if !ext.Critical { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "acmeValidationV1 extension not critical"))); err != nil { - return nil, err - } - return tc, nil - } - - var extValue []byte - rest, err := asn1.Unmarshal(ext.Value, &extValue) - - if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "malformed acmeValidationV1 extension value"))); err != nil { - return nil, err - } - return tc, nil - } - - if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "expected acmeValidationV1 extension value %s for this challenge but got %s", - hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))); err != nil { - return nil, err - } - return tc, nil - } - - upd := &tlsALPN01Challenge{tc.baseChallenge.clone()} - upd.Status = StatusValid - upd.Error = nil - upd.Validated = clock.Now() - - if err := upd.save(db, tc); err != nil { - return nil, err - } - return upd, nil - } - - if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { - foundIDPeAcmeIdentifierV1Obsolete = true - } - } - - if foundIDPeAcmeIdentifierV1Obsolete { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))); err != nil { - return nil, err - } - return tc, nil - } - - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "missing acmeValidationV1 extension"))); err != nil { - return nil, err - } - return tc, nil -} - -// dns01Challenge represents an dns-01 acme challenge. -type dns01Challenge struct { - *baseChallenge -} - -// newDNS01Challenge returns a new acme dns-01 challenge. -func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { - bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) - if err != nil { - return nil, err - } - bc.Type = "dns-01" - bc.Value = ops.Identifier.Value - - dc := &dns01Challenge{bc} - if err := dc.save(db, nil); err != nil { - return nil, err - } - return dc, nil -} - -// KeyAuthorization creates the ACME key authorization value from a token -// and a jwk. -func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { - thumbprint, err := jwk.Thumbprint(crypto.SHA256) - if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) - } - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - return fmt.Sprintf("%s.%s", token, encPrint), nil -} - -// validate attempts to validate the challenge. If the challenge has been -// satisfactorily validated, the 'status' and 'validated' attributes are -// updated. -func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - // If already valid or invalid then return without performing validation. - if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid { - return dc, nil - } - - // Normalize domain for wildcard DNS names - // This is done to avoid making TXT lookups for domains like - // _acme-challenge.*.example.com - // Instead perform txt lookup for _acme-challenge.example.com - domain := strings.TrimPrefix(dc.Value, "*.") - - txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) - if err != nil { - if err = dc.storeError(db, - DNSErr(errors.Wrapf(err, "error looking up TXT "+ - "records for domain %s", domain))); err != nil { - return nil, err - } - return dc, nil - } - - expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) - if err != nil { - return nil, err - } - h := sha256.Sum256([]byte(expectedKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) - var found bool - for _, r := range txtRecords { - if r == expected { - found = true - break - } - } - if !found { - if err = dc.storeError(db, - RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ - "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil { - return nil, err - } - return dc, nil - } - - // Update and store the challenge. - upd := &dns01Challenge{dc.baseChallenge.clone()} - upd.Status = StatusValid - upd.Error = nil - upd.Validated = time.Now().UTC() - - if err := upd.save(db, dc); err != nil { - return nil, err - } - return upd, nil -} - -// getChallenge retrieves and unmarshals an ACME challenge type from the database. -func getChallenge(db nosql.DB, id string) (challenge, error) { - b, err := db.Get(challengeTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id)) - } - ch, err := unmarshalChallenge(b) - if err != nil { - return nil, err - } - return ch, nil -} diff --git a/acme/db/db.go b/acme/db/db.go index 2882fcf4..19449d76 100644 --- a/acme/db/db.go +++ b/acme/db/db.go @@ -4,28 +4,28 @@ import "context" // DB is the DB interface expected by the step-ca ACME API. type DB interface { - CreateAccount(ctx context.Context, acc *Account) (*Account, error) - GetAccount(ctx context.Context, id string) (*Account, error) - GetAccountByKeyID(ctx context.Context) (*Account, error) - UpdateAccount(ctx context.Context, acc *Account) error + CreateAccount(ctx context.Context, acc *types.Account) (*types.Account, error) + GetAccount(ctx context.Context, id string) (*types.Account, error) + GetAccountByKeyID(ctx context.Context, kid string) (*types.Account, error) + UpdateAccount(ctx context.Context, acc *types.Account) error - CreateNonce(ctx context.Context) (Nonce, error) - DeleteNonce(ctx context.Context, nonce Nonce) error + CreateNonce(ctx context.Context) (types.Nonce, error) + DeleteNonce(ctx context.Context, nonce types.Nonce) error - CreateAuthorization(ctx context.Context, authz *Authorization) error - GetAuthorization(ctx context.Context, id string) (*Authorization, error) - UpdateAuthorization(ctx context.Context, authz *Authorization) error + CreateAuthorization(ctx context.Context, authz *types.Authorization) error + GetAuthorization(ctx context.Context, id string) (*types.Authorization, error) + UpdateAuthorization(ctx context.Context, authz *types.Authorization) error - CreateCertificate(ctx context.Context, cert *Certificate) error - GetCertificate(ctx context.Context, id string) (*Certificate, error) + CreateCertificate(ctx context.Context, cert *types.Certificate) error + GetCertificate(ctx context.Context, id string) (*types.Certificate, error) - CreateChallenge(ctx context.Context, ch *Challenge) error - GetChallenge(ctx context.Context, id string) (*Challenge, error) - UpdateChallenge(ctx context.Context, ch *Challenge) error + CreateChallenge(ctx context.Context, ch *types.Challenge) error + GetChallenge(ctx context.Context, id, authzID string) (*types.Challenge, error) + UpdateChallenge(ctx context.Context, ch *types.Challenge) error - CreateOrder(ctx context.Context, o *Order) error + CreateOrder(ctx context.Context, o *types.Order) error DeleteOrder(ctx context.Context, id string) error - GetOrder(ctx context.Context, id string) (*Order, error) + GetOrder(ctx context.Context, id string) (*types.Order, error) GetOrdersByAccountID(ctx context.Context, accountID string) error - UpdateOrder(ctx context.Context, o *Order) error + UpdateOrder(ctx context.Context, o *types.Order) error } diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index dcbfd2f5..dd93c5d2 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -20,7 +20,12 @@ type dbAccount struct { Status string `json:"status"` } -func (db *DB) saveAccount(nu *dbAccount, old *dbAccount) error { +func (dba *dbAccount) clone() *dbAccount { + nu := *dba + return &nu +} + +func (db *DB) saveDBAccount(nu *dbAccount, old *dbAccount) error { var ( err error oldB []byte @@ -51,7 +56,7 @@ func (db *DB) saveAccount(nu *dbAccount, old *dbAccount) error { } // CreateAccount imlements the AcmeDB.CreateAccount interface. -func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { +func (db *DB) CreateAccount(ctx context.Context, acc *types.Account) error { id, err := randID() if err != nil { return nil, err @@ -79,7 +84,7 @@ func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { case !swapped: return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) default: - if err = db.saveAccount(dba, nil); err != nil { + if err = db.saveDBAccount(dba, nil); err != nil { db.db.Del(accountByKeyIDTable, kidB) return err } @@ -87,52 +92,76 @@ func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { } } +// GetAccount retrieves an ACME account by ID. +func (db *DB) GetAccount(ctx context.Context, id string) (*types.Account, error) { + + return &types.Account{ + Status: dbacc.Status, + Contact: dbacc.Contact, + Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), + Key: dbacc.Key, + ID: dbacc.ID, + }, nil +} + +// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). +func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*types.Account, error) { + id, err := db.getAccountIDByKeyID(kid) + if err != nil { + return nil, err + } + return db.GetAccount(ctx, id) +} + // UpdateAccount imlements the AcmeDB.UpdateAccount interface. -func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error { - kid, err := keyToID(dba.Key) +func (db *DB) UpdateAccount(ctx context.Context, acc *types.Account) error { + kid := "from-context" + + old, err := db.getDBAccountByKeyID(ctx, kid) if err != nil { return err } - dba, err := db.db.getAccountByKeyID(ctx, kid) - - newdba := *dba - newdba.Contact = acc.contact - newdba.Status = acc.Status + nu := old.clone() + nu.Contact = acc.contact + nu.Status = acc.Status // If the status has changed to 'deactivated', then set deactivatedAt timestamp. - if acc.Status == types.StatusDeactivated && dba.Status != types.Status.Deactivated { - newdba.Deactivated = clock.Now() + if acc.Status == types.StatusDeactivated && old.Status != types.Status.Deactivated { + nu.Deactivated = clock.Now() } - return db.saveAccount(newdba, dba) + return db.saveDBAccount(newdba, dba) } -// getAccountByID retrieves the account with the given ID. -func (db *DB) getAccountByID(ctx context.Context, id string) (*dbAccount, error) { - ab, err := db.db.Get(accountTable, []byte(id)) +func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { + id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) + return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) - } - - a := new(account) - if err = json.Unmarshal(ab, a); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) + return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) } - return a, nil + return string(id), nil } -// getAccountByKeyID retrieves Id associated with the given Kid. -func (db *DB) getAccountByKeyID(ctx context.Context, kid string) (*dbAccount, error) { - id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) +// getDBAccountByKeyID retrieves Id associated with the given Kid. +func (db *DB) getDBAccountByKeyID(ctx context.Context, kid string) (*dbAccount, error) { + id, err := db.getAccountIDByKeyID(ctx, kid) + if err != nil { + return err + } + data, err := db.db.Get(accountTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) + return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) + return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) + } + + dbacc := new(account) + if err = json.Unmarshal(data, dbacc); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) } - return getAccountByID(db, string(id)) + return dbacc, nil } diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index e69de29b..32e42a69 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -0,0 +1,216 @@ +package nosql + +import ( + "context" + "encoding/json" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql" +) + +var defaultExpiryDuration = time.Hour * 24 + +// dbAuthz is the base authz type that others build from. +type dbAuthz struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier Identifier `json:"identifier"` + Status string `json:"status"` + Expires time.Time `json:"expires"` + Challenges []string `json:"challenges"` + Wildcard bool `json:"wildcard"` + Created time.Time `json:"created"` + Error *Error `json:"error"` +} + +func (ba *dbAuthz) clone() *dbAuthz { + u := *ba + return &u +} + +func (db *DB) saveDBAuthz(ctx context.Context, nu *authz, old *dbAuthz) error { + var ( + err error + oldB, newB []byte + ) + + if old == nil { + oldB = nil + } else { + if oldB, err = json.Marshal(old); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old authz")) + } + } + if newB, err = json.Marshal(nu); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling new authz")) + } + _, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error storing authz")) + case !swapped: + return ServerInternalErr(errors.Errorf("error storing authz; " + + "value has changed since last read")) + default: + return nil + } +} + +// getDBAuthz retrieves and unmarshals a database representation of the +// ACME Authorization type. +func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) { + data, err := db.db.Get(authzTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id)) + } + + var dbaz dbAuthz + if err = json.Unmarshal(data, &dbaz); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dbAuthz")) + } + return &dbaz +} + +// GetAuthorization retrieves and unmarshals an ACME authz type from the database. +// Implements acme.DB GetAuthorization interface. +func (db *DB) GetAuthorization(ctx context.Context, id string) (*types.Authorization, error) { + dbaz, err := getDBAuthz(id) + if err != nil { + return nil, err + } + var chs = make([]*Challenge, len(ba.Challenges)) + for i, chID := range dbaz.Challenges { + chs[i], err = db.GetChallenge(ctx, chID) + if err != nil { + return nil, err + } + } + return &types.Authorization{ + Identifier: dbaz.Identifier, + Status: dbaz.Status, + Challenges: chs, + Wildcard: dbaz.Wildcard, + Expires: dbaz.Expires.Format(time.RFC3339), + ID: dbaz.ID, + }, nil +} + +// CreateAuthorization creates an entry in the database for the Authorization. +// Implements the acme.DB.CreateAuthorization interface. +func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) error { + if len(authz.AccountID) == 0 { + return ServerInternalErr(errors.New("AccountID cannot be empty")) + } + az.ID, err = randID() + if err != nil { + return nil, err + } + + now := clock.Now() + dbaz := &dbAuthz{ + ID: az.ID, + AccountID: az.AccountId, + Status: types.StatusPending, + Created: now, + Expires: now.Add(defaultExpiryDuration), + Identifier: az.Identifier, + } + + if strings.HasPrefix(az.Identifier.Value, "*.") { + dbaz.Wildcard = true + dbaz.Identifier = Identifier{ + Value: strings.TrimPrefix(identifier.Value, "*."), + Type: identifier.Type, + } + } + + chIDs := []string{} + chTypes := []string{"dns-01"} + // HTTP and TLS challenges can only be used for identifiers without wildcards. + if !dbaz.Wildcard { + chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) + } + + for _, typ := range chTypes { + ch, err := db.CreateChallenge(ctx, &types.Challenge{ + AccountID: az.AccountID, + AuthzID: az.ID, + Value: az.Identifier.Value, + Type: typ, + }) + if err != nil { + return nil, Wrapf(err, "error creating '%s' challenge", typ) + } + + chIDs = append(chIDs, ch.ID) + } + dbaz.Challenges = chIDs + + return db.saveDBAuthz(ctx, dbaz, nil) +} + +// UpdateAuthorization saves an updated ACME Authorization to the database. +func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization) error { + old, err := db.getDBAuthz(ctx, az.ID) + if err != nil { + return err + } + + nu := old.clone() + + nu.Status = az.Status + nu.Error = az.Error + return db.saveDBAuthz(ctx, nu, old) +} + +/* +// updateStatus attempts to update the status on a dbAuthz and stores the +// updating object if necessary. +func (ba *dbAuthz) updateStatus(db nosql.DB) (authz, error) { + newAuthz := ba.clone() + + now := time.Now().UTC() + switch ba.Status { + case StatusInvalid: + return ba.parent(), nil + case StatusValid: + return ba.parent(), nil + case StatusPending: + // check expiry + if now.After(ba.Expires) { + newAuthz.Status = StatusInvalid + newAuthz.Error = MalformedErr(errors.New("authz has expired")) + break + } + + var isValid = false + for _, chID := range ba.Challenges { + ch, err := getChallenge(db, chID) + if err != nil { + return ba, err + } + if ch.getStatus() == StatusValid { + isValid = true + break + } + } + + if !isValid { + return ba.parent(), nil + } + newAuthz.Status = StatusValid + newAuthz.Error = nil + default: + return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) + } + + if err := newAuthz.save(db, ba); err != nil { + return ba, err + } + return newAuthz.parent(), nil +} +*/ diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index e69de29b..6303f005 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -0,0 +1,528 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql" +) + +// ChallengeOptions is the type used to created a new Challenge. +type ChallengeOptions struct { + AccountID string + AuthzID string + Identifier Identifier +} + +// dbChallenge is the base Challenge type that others build from. +type dbChallenge struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + AuthzID string `json:"authzID"` + Type string `json:"type"` + Status string `json:"status"` + Token string `json:"token"` + Value string `json:"value"` + Validated time.Time `json:"validated"` + Created time.Time `json:"created"` + Error *AError `json:"error"` +} + +func (dbc *dbChallenge) clone() *dbChallenge { + u := *bc + return &u +} + +// save writes the challenge to disk. For new challenges 'old' should be nil, +// otherwise 'old' should be a pointer to the acme challenge as it was at the +// start of the request. This method will fail if the value currently found +// in the bucket/row does not match the value of 'old'. +func (db *DB) saveDBChallenge(ctx context.Context, nu challenge, old challenge) error { + newB, err := json.Marshal(nu) + if err != nil { + return ServerInternalErr(errors.Wrap(err, + "error marshaling new acme challenge")) + } + var oldB []byte + if old == nil { + oldB = nil + } else { + oldB, err = json.Marshal(old) + if err != nil { + return ServerInternalErr(errors.Wrap(err, + "error marshaling old acme challenge")) + } + } + + _, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error saving acme challenge")) + case !swapped: + return ServerInternalErr(errors.New("error saving acme challenge; " + + "acme challenge has changed since last read")) + default: + return nil + } +} + +func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { + data, err := db.db.Get(challengeTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id)) + } + + dbch := new(baseChallenge) + if err := json.Unmarshal(data, dbch); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ + "challenge type into dbChallenge")) + } + return dbch +} + +// CreateChallenge creates a new ACME challenge data structure in the database. +// Implements acme.DB.CreateChallenge interface. +func (db *DB) CreateChallenge(ctx context.context, ch *types.Challenge) error { + if len(ch.AuthzID) == 0 { + return ServerInternalError(errors.New("AuthzID cannot be empty")) + } + if len(ch.AccountID) == 0 { + return ServerInternalError(errors.New("AccountID cannot be empty")) + } + if len(ch.Value) == 0 { + return ServerInternalError(errors.New("AccountID cannot be empty")) + } + // TODO: verify that challenge type is set and is one of expected types. + if len(ch.Type) == 0 { + return ServerInternalError(errors.New("Type cannot be empty")) + } + + ch.ID, err = randID() + if err != nil { + return nil, Wrap(err, "error generating random id for ACME challenge") + } + ch.Token, err = randID() + if err != nil { + return nil, Wrap(err, "error generating token for ACME challenge") + } + + dbch := &dbChallenge{ + ID: ch.ID, + AuthzID: ch.AuthzID, + AccountID: ch.AccountID, + Value: ch.Value, + Status: types.StatusPending, + Token: ch.Token, + Created: clock.Now(), + Type: ch.Type, + } + + return dbch.saveDBChallenge(ctx, dbch, nil) +} + +// GetChallenge retrieves and unmarshals an ACME challenge type from the database. +// Implements the acme.DB GetChallenge interface. +func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*types.Challenge, error) { + dbch, err := db.getDBChallenge(ctx, id) + if err != nil { + return err + } + + ch := &Challenge{ + Type: dbch.Type, + Status: dbch.Status, + Token: dbch.Token, + URL: dir.getLink(ctx, ChallengeLink, true, dbch.getID()), + ID: dbch.ID, + AuthzID: dbch.AuthzID(), + Error: dbch.Error, + } + if !dbch.Validated.IsZero() { + ac.Validated = dbch.Validated.Format(time.RFC3339) + } + return ch, nil +} + +// UpdateChallenge updates an ACME challenge type in the database. +func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error { + old, err := db.getDBChallenge(ctx, id) + if err != nil { + return err + } + + nu := old.clone() + + // These should be the only values chaning in an Update request. + nu.Status = ch.Status + nu.Error = ch.Error + if nu.Status == types.StatusValid { + nu.Validated = clock.Now() + } + + return db.saveDBChallenge(ctx, nu, old) +} + +//// http01Challenge represents an http-01 acme challenge. +//type http01Challenge struct { +// *baseChallenge +//} +// +//// newHTTP01Challenge returns a new acme http-01 challenge. +//func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { +// bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) +// if err != nil { +// return nil, err +// } +// bc.Type = "http-01" +// bc.Value = ops.Identifier.Value +// +// hc := &http01Challenge{bc} +// if err := hc.save(db, nil); err != nil { +// return nil, err +// } +// return hc, nil +//} +// +//// Validate attempts to validate the challenge. If the challenge has been +//// satisfactorily validated, the 'status' and 'validated' attributes are +//// updated. +//func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { +// // If already valid or invalid then return without performing validation. +// if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid { +// return hc, nil +// } +// url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token) +// +// resp, err := vo.httpGet(url) +// if err != nil { +// if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err, +// "error doing http GET for url %s", url))); err != nil { +// return nil, err +// } +// return hc, nil +// } +// if resp.StatusCode >= 400 { +// if err = hc.storeError(db, +// ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", +// url, resp.StatusCode))); err != nil { +// return nil, err +// } +// return hc, nil +// } +// defer resp.Body.Close() +// +// body, err := ioutil.ReadAll(resp.Body) +// if err != nil { +// return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+ +// "response body for url %s", url)) +// } +// keyAuth := strings.Trim(string(body), "\r\n") +// +// expected, err := KeyAuthorization(hc.Token, jwk) +// if err != nil { +// return nil, err +// } +// if keyAuth != expected { +// if err = hc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ +// "expected %s, but got %s", expected, keyAuth))); err != nil { +// return nil, err +// } +// return hc, nil +// } +// +// // Update and store the challenge. +// upd := &http01Challenge{hc.baseChallenge.clone()} +// upd.Status = StatusValid +// upd.Error = nil +// upd.Validated = clock.Now() +// +// if err := upd.save(db, hc); err != nil { +// return nil, err +// } +// return upd, nil +//} +// +//type tlsALPN01Challenge struct { +// *baseChallenge +//} +// +//// newTLSALPN01Challenge returns a new acme tls-alpn-01 challenge. +//func newTLSALPN01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { +// bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) +// if err != nil { +// return nil, err +// } +// bc.Type = "tls-alpn-01" +// bc.Value = ops.Identifier.Value +// +// hc := &tlsALPN01Challenge{bc} +// if err := hc.save(db, nil); err != nil { +// return nil, err +// } +// return hc, nil +//} +// +//func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { +// // If already valid or invalid then return without performing validation. +// if tc.getStatus() == StatusValid || tc.getStatus() == StatusInvalid { +// return tc, nil +// } +// +// config := &tls.Config{ +// NextProtos: []string{"acme-tls/1"}, +// ServerName: tc.Value, +// InsecureSkipVerify: true, // we expect a self-signed challenge certificate +// } +// +// hostPort := net.JoinHostPort(tc.Value, "443") +// +// conn, err := vo.tlsDial("tcp", hostPort, config) +// if err != nil { +// if err = tc.storeError(db, +// ConnectionErr(errors.Wrapf(err, "error doing TLS dial for %s", hostPort))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// defer conn.Close() +// +// cs := conn.ConnectionState() +// certs := cs.PeerCertificates +// +// if len(certs) == 0 { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("%s challenge for %s resulted in no certificates", +// tc.Type, tc.Value))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("cannot negotiate ALPN acme-tls/1 protocol for "+ +// "tls-alpn-01 challenge"))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// leafCert := certs[0] +// +// if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ +// "leaf certificate must contain a single DNS name, %v", tc.Value))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} +// idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} +// foundIDPeAcmeIdentifierV1Obsolete := false +// +// keyAuth, err := KeyAuthorization(tc.Token, jwk) +// if err != nil { +// return nil, err +// } +// hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) +// +// for _, ext := range leafCert.Extensions { +// if idPeAcmeIdentifier.Equal(ext.Id) { +// if !ext.Critical { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ +// "acmeValidationV1 extension not critical"))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// var extValue []byte +// rest, err := asn1.Unmarshal(ext.Value, &extValue) +// +// if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ +// "malformed acmeValidationV1 extension value"))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ +// "expected acmeValidationV1 extension value %s for this challenge but got %s", +// hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// upd := &tlsALPN01Challenge{tc.baseChallenge.clone()} +// upd.Status = StatusValid +// upd.Error = nil +// upd.Validated = clock.Now() +// +// if err := upd.save(db, tc); err != nil { +// return nil, err +// } +// return upd, nil +// } +// +// if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { +// foundIDPeAcmeIdentifierV1Obsolete = true +// } +// } +// +// if foundIDPeAcmeIdentifierV1Obsolete { +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ +// "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))); err != nil { +// return nil, err +// } +// return tc, nil +// } +// +// if err = tc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ +// "missing acmeValidationV1 extension"))); err != nil { +// return nil, err +// } +// return tc, nil +//} +// +//// dns01Challenge represents an dns-01 acme challenge. +//type dns01Challenge struct { +// *baseChallenge +//} +// +//// newDNS01Challenge returns a new acme dns-01 challenge. +//func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { +// bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) +// if err != nil { +// return nil, err +// } +// bc.Type = "dns-01" +// bc.Value = ops.Identifier.Value +// +// dc := &dns01Challenge{bc} +// if err := dc.save(db, nil); err != nil { +// return nil, err +// } +// return dc, nil +//} +// +//// KeyAuthorization creates the ACME key authorization value from a token +//// and a jwk. +//func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { +// thumbprint, err := jwk.Thumbprint(crypto.SHA256) +// if err != nil { +// return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) +// } +// encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) +// return fmt.Sprintf("%s.%s", token, encPrint), nil +//} +// +//// validate attempts to validate the challenge. If the challenge has been +//// satisfactorily validated, the 'status' and 'validated' attributes are +//// updated. +//func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { +// // If already valid or invalid then return without performing validation. +// if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid { +// return dc, nil +// } +// +// // Normalize domain for wildcard DNS names +// // This is done to avoid making TXT lookups for domains like +// // _acme-challenge.*.example.com +// // Instead perform txt lookup for _acme-challenge.example.com +// domain := strings.TrimPrefix(dc.Value, "*.") +// +// txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) +// if err != nil { +// if err = dc.storeError(db, +// DNSErr(errors.Wrapf(err, "error looking up TXT "+ +// "records for domain %s", domain))); err != nil { +// return nil, err +// } +// return dc, nil +// } +// +// expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) +// if err != nil { +// return nil, err +// } +// h := sha256.Sum256([]byte(expectedKeyAuth)) +// expected := base64.RawURLEncoding.EncodeToString(h[:]) +// var found bool +// for _, r := range txtRecords { +// if r == expected { +// found = true +// break +// } +// } +// if !found { +// if err = dc.storeError(db, +// RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ +// "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil { +// return nil, err +// } +// return dc, nil +// } +// +// // Update and store the challenge. +// upd := &dns01Challenge{dc.baseChallenge.clone()} +// upd.Status = StatusValid +// upd.Error = nil +// upd.Validated = time.Now().UTC() +// +// if err := upd.save(db, dc); err != nil { +// return nil, err +// } +// return upd, nil +//} +// +//// unmarshalChallenge unmarshals a challenge type into the correct sub-type. +//func unmarshalChallenge(data []byte) (challenge, error) { +// var getType struct { +// Type string `json:"type"` +// } +// if err := json.Unmarshal(data, &getType); err != nil { +// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type")) +// } +// +// switch getType.Type { +// case "dns-01": +// var bc baseChallenge +// if err := json.Unmarshal(data, &bc); err != nil { +// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ +// "challenge type into dns01Challenge")) +// } +// return &dns01Challenge{&bc}, nil +// case "http-01": +// var bc baseChallenge +// if err := json.Unmarshal(data, &bc); err != nil { +// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ +// "challenge type into http01Challenge")) +// } +// return &http01Challenge{&bc}, nil +// case "tls-alpn-01": +// var bc baseChallenge +// if err := json.Unmarshal(data, &bc); err != nil { +// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ +// "challenge type into tlsALPN01Challenge")) +// } +// return &tlsALPN01Challenge{&bc}, nil +// default: +// return nil, ServerInternalErr(errors.Errorf("unexpected challenge type '%s'", getType.Type)) +// } +//} +// diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index e69de29b..6408cd00 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -0,0 +1,469 @@ +package nosql + +import ( + "context" + "crypto/x509" + "encoding/json" + "sort" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/nosql" + "go.step.sm/crypto/x509util" +) + +var defaultOrderExpiry = time.Hour * 24 + +// Mutex for locking ordersByAccount index operations. +var ordersByAccountMux sync.Mutex + +// OrderOptions options with which to create a new Order. +type OrderOptions struct { + AccountID string `json:"accID"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` + backdate time.Duration + defaultDuration time.Duration +} + +type dbOrder struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires,omitempty"` + Status string `json:"status"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + Error *Error `json:"error,omitempty"` + Authorizations []string `json:"authorizations"` + Certificate string `json:"certificate,omitempty"` +} + +// getDBOrder retrieves and unmarshals an ACME Order type from the database. +func (db *DB) getDBOrder(id string) (*dbOrder, error) { + b, err := db.db.Get(orderTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id)) + } + o := new(dbOrder) + if err := json.Unmarshal(b, &o); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order")) + } + return o, nil +} + +// GetOrder retrieves an ACME Order from the database. +func (db *DB) GetOrder(id string) (*types.Order, error) { + dbo, err := db.getDBOrder(id) + + azs := make([]string, len(dbo.Authorizations)) + for i, aid := range dbo.Authorizations { + azs[i] = dir.getLink(ctx, AuthzLink, true, aid) + } + o := &Order{ + Status: dbo.Status, + Expires: dbo.Expires.Format(time.RFC3339), + Identifiers: dbo.Identifiers, + NotBefore: dbo.NotBefore.Format(time.RFC3339), + NotAfter: dbo.NotAfter.Format(time.RFC3339), + Authorizations: azs, + Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), + ID: dbo.ID, + } + + if dbo.Certificate != "" { + o.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate) + } + return o, nil +} + +// newOrder returns a new Order type. +func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { + id, err := randID() + if err != nil { + return nil, err + } + + authzs := make([]string, len(ops.Identifiers)) + for i, identifier := range ops.Identifiers { + az, err := newAuthz(db, ops.AccountID, identifier) + if err != nil { + return nil, err + } + authzs[i] = az.getID() + } + + now := clock.Now() + var backdate time.Duration + nbf := ops.NotBefore + if nbf.IsZero() { + nbf = now + backdate = -1 * ops.backdate + } + naf := ops.NotAfter + if naf.IsZero() { + naf = nbf.Add(ops.defaultDuration) + } + + o := &order{ + ID: id, + AccountID: ops.AccountID, + Created: now, + Status: StatusPending, + Expires: now.Add(defaultOrderExpiry), + Identifiers: ops.Identifiers, + NotBefore: nbf.Add(backdate), + NotAfter: naf, + Authorizations: authzs, + } + if err := o.save(db, nil); err != nil { + return nil, err + } + + var oidHelper = orderIDsByAccount{} + _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID) + if err != nil { + return nil, err + } + return o, nil +} + +type orderIDsByAccount struct{} + +// addOrderID adds an order ID to a users index of in progress order IDs. +// This method will also cull any orders that are no longer in the `pending` +// state from the index before returning it. +func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { + ordersByAccountMux.Lock() + defer ordersByAccountMux.Unlock() + + // Update the "order IDs by account ID" index + oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID) + if err != nil { + return nil, err + } + newOids := append(oids, oid) + if err = orderIDs(newOids).save(db, oids, accID); err != nil { + // Delete the entire order if storing the index fails. + db.Del(orderTable, []byte(oid)) + return nil, err + } + return newOids, nil +} + +// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the +// account. +func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { + b, err := db.Get(ordersByAccountIDTable, []byte(accID)) + if err != nil { + if nosql.IsErrNotFound(err) { + return []string{}, nil + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) + } + var oids []string + if err := json.Unmarshal(b, &oids); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) + } + + // Remove any order that is not in PENDING state and update the stored list + // before returning. + // + // According to RFC 8555: + // The server SHOULD include pending orders and SHOULD NOT include orders + // that are invalid in the array of URLs. + pendOids := []string{} + for _, oid := range oids { + o, err := getOrder(db, oid) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) + } + if o, err = o.updateStatus(db); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) + } + if o.Status == StatusPending { + pendOids = append(pendOids, oid) + } + } + // If the number of pending orders is less than the number of orders in the + // list, then update the pending order list. + if len(pendOids) != len(oids) { + if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ + "len(orderIDs) = %d", len(pendOids))) + } + } + + return pendOids, nil +} + +type orderIDs []string + +// save is used to update the list of orderIDs keyed by ACME account ID +// stored in the database. +// +// This method always converts empty lists to 'nil' when storing to the DB. We +// do this to avoid any confusion between an empty list and a nil value in the +// db. +func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { + var ( + err error + oldb []byte + newb []byte + ) + if len(old) == 0 { + oldb = nil + } else { + oldb, err = json.Marshal(old) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) + } + } + if len(oids) == 0 { + newb = nil + } else { + newb, err = json.Marshal(oids) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) + } + } + _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) + switch { + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) + case !swapped: + return ServerInternalErr(errors.Errorf("error storing order IDs "+ + "for account %s; order IDs changed since last read", accID)) + default: + return nil + } +} + +func (o *order) save(db nosql.DB, old *order) error { + var ( + err error + oldB []byte + ) + if old == nil { + oldB = nil + } else { + if oldB, err = json.Marshal(old); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) + } + } + + newB, err := json.Marshal(o) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order")) + } + + _, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error storing order")) + case !swapped: + return ServerInternalErr(errors.New("error storing order; " + + "value has changed since last read")) + default: + return nil + } +} + +// updateStatus updates order status if necessary. +func (o *order) updateStatus(db nosql.DB) (*order, error) { + _newOrder := *o + newOrder := &_newOrder + + now := time.Now().UTC() + switch o.Status { + case StatusInvalid: + return o, nil + case StatusValid: + return o, nil + case StatusReady: + // check expiry + if now.After(o.Expires) { + newOrder.Status = StatusInvalid + newOrder.Error = MalformedErr(errors.New("order has expired")) + break + } + return o, nil + case StatusPending: + // check expiry + if now.After(o.Expires) { + newOrder.Status = StatusInvalid + newOrder.Error = MalformedErr(errors.New("order has expired")) + break + } + + var count = map[string]int{ + StatusValid: 0, + StatusInvalid: 0, + StatusPending: 0, + } + for _, azID := range o.Authorizations { + az, err := getAuthz(db, azID) + if err != nil { + return nil, err + } + if az, err = az.updateStatus(db); err != nil { + return nil, err + } + st := az.getStatus() + count[st]++ + } + switch { + case count[StatusInvalid] > 0: + newOrder.Status = StatusInvalid + + // No change in the order status, so just return the order as is - + // without writing any changes. + case count[StatusPending] > 0: + return newOrder, nil + + case count[StatusValid] == len(o.Authorizations): + newOrder.Status = StatusReady + + default: + return nil, ServerInternalErr(errors.New("unexpected authz status")) + } + default: + return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) + } + + if err := newOrder.save(db, o); err != nil { + return nil, err + } + return newOrder, nil +} + +// finalize signs a certificate if the necessary conditions for Order completion +// have been met. +func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) { + var err error + if o, err = o.updateStatus(db); err != nil { + return nil, err + } + + switch o.Status { + case StatusInvalid: + return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) + case StatusValid: + return o, nil + case StatusPending: + return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) + case StatusReady: + break + default: + return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) + } + + // RFC8555: The CSR MUST indicate the exact same set of requested + // identifiers as the initial newOrder request. Identifiers of type "dns" + // MUST appear either in the commonName portion of the requested subject + // name or in an extensionRequest attribute [RFC2985] requesting a + // subjectAltName extension, or both. + if csr.Subject.CommonName != "" { + csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) + } + csr.DNSNames = uniqueLowerNames(csr.DNSNames) + orderNames := make([]string, len(o.Identifiers)) + for i, n := range o.Identifiers { + orderNames[i] = n.Value + } + orderNames = uniqueLowerNames(orderNames) + + // Validate identifier names against CSR alternative names. + // + // Note that with certificate templates we are not going to check for the + // absence of other SANs as they will only be set if the templates allows + // them. + if len(csr.DNSNames) != len(orderNames) { + return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + } + + sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) + for i := range csr.DNSNames { + if csr.DNSNames[i] != orderNames[i] { + return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + } + sans[i] = x509util.SubjectAlternativeName{ + Type: x509util.DNSType, + Value: csr.DNSNames[i], + } + } + + // Get authorizations from the ACME provisioner. + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + signOps, err := p.AuthorizeSign(ctx, "") + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) + } + + // Template data + data := x509util.NewTemplateData() + data.SetCommonName(csr.Subject.CommonName) + data.Set(x509util.SANsKey, sans) + + templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) + } + signOps = append(signOps, templateOptions) + + // Create and store a new certificate. + certChain, err := auth.Sign(csr, provisioner.SignOptions{ + NotBefore: provisioner.NewTimeDuration(o.NotBefore), + NotAfter: provisioner.NewTimeDuration(o.NotAfter), + }, signOps...) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) + } + + cert, err := newCert(db, CertOptions{ + AccountID: o.AccountID, + OrderID: o.ID, + Leaf: certChain[0], + Intermediates: certChain[1:], + }) + if err != nil { + return nil, err + } + + _newOrder := *o + newOrder := &_newOrder + newOrder.Certificate = cert.ID + newOrder.Status = StatusValid + if err := newOrder.save(db, o); err != nil { + return nil, err + } + return newOrder, nil +} + +// toACME converts the internal Order type into the public acmeOrder type for +// presentation in the ACME protocol. +func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) { +} + +// uniqueLowerNames returns the set of all unique names in the input after all +// of them are lowercased. The returned names will be in their lowercased form +// and sorted alphabetically. +func uniqueLowerNames(names []string) (unique []string) { + nameMap := make(map[string]int, len(names)) + for _, name := range names { + nameMap[strings.ToLower(name)] = 1 + } + unique = make([]string, 0, len(nameMap)) + for name := range nameMap { + unique = append(unique, name) + } + sort.Strings(unique) + return +} diff --git a/acme/types/authz.go b/acme/types/authz.go index 3e3a5aa7..4119f6c1 100644 --- a/acme/types/authz.go +++ b/acme/types/authz.go @@ -6,15 +6,15 @@ import ( "github.com/pkg/errors" ) -// Authz is a subset of the Authz type containing only those attributes -// required for responses in the ACME protocol. -type Authz struct { +// Authorization representst an ACME Authorization. +type Authorization struct { Identifier Identifier `json:"identifier"` Status string `json:"status"` Expires string `json:"expires"` Challenges []*Challenge `json:"challenges"` Wildcard bool `json:"wildcard"` ID string `json:"-"` + AccountID string `json:"-"` } // ToLog enables response logging. @@ -25,8 +25,3 @@ func (a *Authz) ToLog() (interface{}, error) { } return string(b), nil } - -// GetID returns the Authz ID. -func (a *Authz) GetID() string { - return a.ID -} diff --git a/acme/types/challenge.go b/acme/types/challenge.go new file mode 100644 index 00000000..61bcd2fb --- /dev/null +++ b/acme/types/challenge.go @@ -0,0 +1,30 @@ +package types + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +// Challenge represents an ACME response Challenge type. +type Challenge struct { + Type string `json:"type"` + Status string `json:"status"` + Token string `json:"token"` + Validated string `json:"validated,omitempty"` + URL string `json:"url"` + Error *AError `json:"error,omitempty"` + ID string `json:"-"` + AuthzID string `json:"-"` + AccountID string `json:"-"` + Value string `json:"-"` +} + +// ToLog enables response logging. +func (c *Challenge) ToLog() (interface{}, error) { + b, err := json.Marshal(c) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) + } + return string(b), nil +} diff --git a/acme/types/order.go b/acme/types/order.go new file mode 100644 index 00000000..f14dc0bb --- /dev/null +++ b/acme/types/order.go @@ -0,0 +1,30 @@ +package types + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +// Order contains order metadata for the ACME protocol order type. +type Order struct { + Status string `json:"status"` + Expires string `json:"expires,omitempty"` + Identifiers []Identifier `json:"identifiers"` + NotBefore string `json:"notBefore,omitempty"` + NotAfter string `json:"notAfter,omitempty"` + Error interface{} `json:"error,omitempty"` + Authorizations []string `json:"authorizations"` + Finalize string `json:"finalize"` + Certificate string `json:"certificate,omitempty"` + ID string `json:"-"` +} + +// ToLog enables response logging. +func (o *Order) ToLog() (interface{}, error) { + b, err := json.Marshal(o) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging")) + } + return string(b), nil +} diff --git a/acme/types/status.go b/acme/types/status.go new file mode 100644 index 00000000..c98a506e --- /dev/null +++ b/acme/types/status.go @@ -0,0 +1,20 @@ +package types + +// Status represents an ACME status. +type Status string + +var ( + // StatusValid -- valid + StatusValid = Status("valid") + // StatusInvalid -- invalid + StatusInvalid = Status("invalid") + // StatusPending -- pending; e.g. an Order that is not ready to be finalized. + StatusPending = Status("pending") + // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid. + StatusDeactivated = Status("deactivated") + // StatusReady -- ready; e.g. for an Order that is ready to be finalized. + StatusReady = Status("ready") + //statusExpired = "expired" + //statusActive = "active" + //statusProcessing = "processing" +) From 0368957e79244ac13e8a9365632076b5c629385b Mon Sep 17 00:00:00 2001 From: max furman Date: Fri, 26 Feb 2021 10:17:18 -0800 Subject: [PATCH 04/47] [acmedb] (wip) --- acme/db/nosql/order.go | 13 ++++++++----- acme/types/order.go | 1 + 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 6408cd00..5ba54790 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -84,9 +84,8 @@ func (db *DB) GetOrder(id string) (*types.Order, error) { return o, nil } -// newOrder returns a new Order type. -func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { - id, err := randID() +func (db *DB) CreateOrder(ctx context.Context, o *types.Order) error { + o.ID, err := randID() if err != nil { return nil, err } @@ -112,7 +111,7 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { naf = nbf.Add(ops.defaultDuration) } - o := &order{ + dbo := &dbOrder{ ID: id, AccountID: ops.AccountID, Created: now, @@ -123,7 +122,7 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { NotAfter: naf, Authorizations: authzs, } - if err := o.save(db, nil); err != nil { + if err := db.saveDBOrder(dbo, nil); err != nil { return nil, err } @@ -135,6 +134,10 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { return o, nil } +// newOrder returns a new Order type. +func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { +} + type orderIDsByAccount struct{} // addOrderID adds an order ID to a users index of in progress order IDs. diff --git a/acme/types/order.go b/acme/types/order.go index f14dc0bb..7e472204 100644 --- a/acme/types/order.go +++ b/acme/types/order.go @@ -18,6 +18,7 @@ type Order struct { Finalize string `json:"finalize"` Certificate string `json:"certificate,omitempty"` ID string `json:"-"` + ProvisionerID string `json:"-"` } // ToLog enables response logging. From 461bad3fefce23d0668284296e414ee206fdf863 Mon Sep 17 00:00:00 2001 From: max furman Date: Sat, 27 Feb 2021 17:05:37 -0800 Subject: [PATCH 05/47] [acme db interface] wip --- acme/{types => }/account.go | 0 acme/authority.go | 83 ++++---- acme/authorization.go | 75 +++++++ acme/challenge.go | 262 +++++++++++++++++++++++ acme/common.go | 11 + acme/{db => }/db.go | 0 acme/db/nosql/account.go | 52 +---- acme/db/nosql/authz.go | 110 ++-------- acme/db/nosql/challenge.go | 403 +----------------------------------- acme/db/nosql/nonce.go | 12 +- acme/db/nosql/nosql.go | 46 ++++ acme/db/nosql/order.go | 298 ++++---------------------- acme/errors.go | 3 + acme/{types => }/nonce.go | 0 acme/order.go | 371 +++++---------------------------- acme/{types => }/status.go | 0 acme/types/authz.go | 27 --- acme/types/challenge.go | 30 --- acme/types/order.go | 31 --- 19 files changed, 565 insertions(+), 1249 deletions(-) rename acme/{types => }/account.go (100%) create mode 100644 acme/authorization.go create mode 100644 acme/challenge.go rename acme/{db => }/db.go (100%) rename acme/{types => }/nonce.go (100%) rename acme/{types => }/status.go (100%) delete mode 100644 acme/types/authz.go delete mode 100644 acme/types/challenge.go delete mode 100644 acme/types/order.go diff --git a/acme/types/account.go b/acme/account.go similarity index 100% rename from acme/types/account.go rename to acme/account.go diff --git a/acme/authority.go b/acme/authority.go index c9190811..c0b6a732 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" + "log" "net" "net/http" "net/url" @@ -69,17 +70,6 @@ type AuthorityOptions struct { Prefix string } -var ( - accountTable = []byte("acme_accounts") - accountByKeyIDTable = []byte("acme_keyID_accountID_index") - authzTable = []byte("acme_authzs") - challengeTable = []byte("acme_challenges") - nonceTable = []byte("nonces") - orderTable = []byte("acme_orders") - ordersByAccountIDTable = []byte("acme_account_orders_index") - certTable = []byte("acme_certs") -) - // NewAuthority returns a new Authority that implements the ACME interface. // // Deprecated: NewAuthority exists for hitorical compatibility and should not @@ -197,14 +187,23 @@ func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) ( // GetOrder returns an ACME order. func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) { - o, err := getOrder(a.db, orderID) + prov, err := ProvisionerFromContext(ctx) if err != nil { return nil, err } + o, err := a.db.GetOrder(ctx, orderID) + if err != nil { + return nil, ServerInternalErr(err) + } if accID != o.AccountID { + log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) return nil, UnauthorizedErr(errors.New("account does not own order")) } - if o, err = o.updateStatus(a.db); err != nil { + if prov.GetID() != o.ProvisionerID { + log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) + return nil, UnauthorizedErr(errors.New("provisioner does not own order")) + } + if err = a.updateOrderStatus(ctx, o); err != nil { return nil, err } return o.toACME(ctx, a.db, a.dir) @@ -234,13 +233,15 @@ func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, err if err != nil { return nil, err } - ops.backdate = a.backdate.Duration - ops.defaultDuration = prov.DefaultTLSCertDuration() - order, err := newOrder(a.db, ops) - if err != nil { - return nil, Wrap(err, "error creating order") - } - return order.toACME(ctx, a.db, a.dir) + return db.CreateOrder(ctx, &Order{ + AccountID: ops.AccountID, + ProvisionerID: prov.GetID(), + Backdate: a.backdate.Duration, + DefaultDuration: prov.DefaultTLSCertDuration(), + Identifiers: ops.Identifiers, + NotBefore: ops.NotBefore, + NotAfter: ops.NotAfter, + }) } // FinalizeOrder attempts to finalize an order and generate a new certificate. @@ -249,44 +250,51 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs if err != nil { return nil, err } - o, err := getOrder(a.db, orderID) + o, err := a.db.GetOrder(ctx, orderID) if err != nil { return nil, err } if accID != o.AccountID { + log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) return nil, UnauthorizedErr(errors.New("account does not own order")) } - o, err = o.finalize(a.db, csr, a.signAuth, prov) + if prov.GetID() != o.ProvisionerID { + log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) + return nil, UnauthorizedErr(errors.New("provisioner does not own order")) + } + o, err = o.Finalize(ctx, a.db, csr, a.signAuth, prov) if err != nil { return nil, Wrap(err, "error finalizing order") } - return o.toACME(ctx, a.db, a.dir) + return o, nil } // GetAuthz retrieves and attempts to update the status on an ACME authz // before returning. func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) { - az, err := getAuthz(a.db, authzID) + az, err := a.db.GetAuthorization(ctx, authzID) if err != nil { return nil, err } - if accID != az.getAccountID() { + if accID != az.AccountID { + log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) return nil, UnauthorizedErr(errors.New("account does not own authz")) } - az, err = az.updateStatus(a.db) + az, err = az.UpdateStatus(ctx, a.db) if err != nil { return nil, Wrap(err, "error updating authz status") } - return az.toACME(ctx, a.db, a.dir) + return az, nil } // ValidateChallenge attempts to validate the challenge. func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { - ch, err := getChallenge(a.db, chID) + ch, err := a.db.GetChallenge(ctx, chID, "todo") if err != nil { return nil, err } - if accID != ch.getAccountID() { + if accID != ch.AccountID { + log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, ch.AccountID) return nil, UnauthorizedErr(errors.New("account does not own challenge")) } client := http.Client{ @@ -295,17 +303,16 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j dialer := &net.Dialer{ Timeout: 30 * time.Second, } - ch, err = ch.validate(a.db, jwk, validateOptions{ + if err = ch.Validate(ctx, a.db, jwk, validateOptions{ httpGet: client.Get, lookupTxt: net.LookupTXT, tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(dialer, network, addr, config) }, - }) - if err != nil { + }); err != nil { return nil, Wrap(err, "error attempting challenge validation") } - return ch.toACME(ctx, a.db, a.dir) + return ch, nil } // GetCertificate retrieves the Certificate by ID. @@ -319,13 +326,3 @@ func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) { } return cert.toACME(a.db, a.dir) } - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -type validateOptions struct { - httpGet httpGetter - lookupTxt lookupTxt - tlsDial tlsDialer -} diff --git a/acme/authorization.go b/acme/authorization.go new file mode 100644 index 00000000..a41950cd --- /dev/null +++ b/acme/authorization.go @@ -0,0 +1,75 @@ +package types + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// Authorization representst an ACME Authorization. +type Authorization struct { + Identifier *Identifier `json:"identifier"` + Status string `json:"status"` + Expires string `json:"expires"` + Challenges []*Challenge `json:"challenges"` + Wildcard bool `json:"wildcard"` + ID string `json:"-"` + AccountID string `json:"-"` +} + +// ToLog enables response logging. +func (az *Authorization) ToLog() (interface{}, error) { + b, err := json.Marshal(az) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) + } + return string(b), nil +} + +// UpdateStatus updates the ACME Authorization Status if necessary. +// Changes to the Authorization are saved using the database interface. +func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { + now := time.Now().UTC() + expiry, err := time.Parse(time.RFC3339, az.Expires) + if err != nil { + return ServerInternalErr(errors.Wrap("error converting expiry string to time")) + } + + switch az.Status { + case StatusInvalid: + return nil + case StatusValid: + return nil + case StatusPending: + // check expiry + if now.After(expiry) { + az.Status = StatusInvalid + az.Error = MalformedErr(errors.New("authz has expired")) + break + } + + var isValid = false + for _, chID := range ba.Challenges { + ch, err := db.GetChallenge(ctx, chID, az.ID) + if err != nil { + return ServerInternalErr(err) + } + if ch.Status == StatusValid { + isValid = true + break + } + } + + if !isValid { + return nil + } + az.Status = StatusValid + az.Error = nil + default: + return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) + } + + return ServerInternalErr(db.UpdateAuthorization(ctx, az)) +} diff --git a/acme/challenge.go b/acme/challenge.go new file mode 100644 index 00000000..de178d6c --- /dev/null +++ b/acme/challenge.go @@ -0,0 +1,262 @@ +package types + +import ( + "context" + "crypto" + "crypto/sha256" + "crypto/subtle" + "crypto/tls" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql" + "go.step.sm/crypto/jose" +) + +// Challenge represents an ACME response Challenge type. +type Challenge struct { + Type string `json:"type"` + Status string `json:"status"` + Token string `json:"token"` + Validated string `json:"validated,omitempty"` + URL string `json:"url"` + Error *AError `json:"error,omitempty"` + ID string `json:"-"` + AuthzID string `json:"-"` + AccountID string `json:"-"` + Value string `json:"-"` +} + +// ToLog enables response logging. +func (ch *Challenge) ToLog() (interface{}, error) { + b, err := json.Marshal(ch) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) + } + return string(b), nil +} + +// Validate attempts to validate the challenge. Stores changes to the Challenge +// type using the DB interface. +// satisfactorily validated, the 'status' and 'validated' attributes are +// updated. +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { + // If already valid or invalid then return without performing validation. + if ch.Status == StatusValid || ch.Status == StatusInvalid { + return nil + } + switch ch.Type { + case "http-01": + return http01Validate(ctx, ch, db, jwk, vo) + case "dns-01": + return dns01Validate(ctx, ch, db, jwk, vo) + case "tls-alpn-01": + return tlsalpn01Validate(ctx, ch, db, jwk, vo) + default: + return ServerInternalErr(errors.Errorf("unexpected challenge type '%s'", ch.Type)) + } +} + +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { + url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token) + + resp, err := vo.httpGet(url) + if err != nil { + return storeError(ctx, ch, db, ConnectionErr(errors.Wrapf(err, + "error doing http GET for url %s", url))) + } + if resp.StatusCode >= 400 { + return storeError(ctx, ch, db, ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", + url, resp.StatusCode))) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return ServerInternalErr(errors.Wrapf(err, "error reading "+ + "response body for url %s", url)) + } + keyAuth := strings.Trim(string(body), "\r\n") + + expected, err := KeyAuthorization(ch.Token, jwk) + if err != nil { + return err + } + if keyAuth != expected { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ + "expected %s, but got %s", expected, keyAuth))) + } + + // Update and store the challenge. + ch.Status = StatusValid + ch.Error = nil + ch.Validated = clock.Now() + + return ServerInternalErr(db.UpdateChallenge(ctx, ch)) +} + +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { + config := &tls.Config{ + NextProtos: []string{"acme-tls/1"}, + ServerName: tc.Value, + InsecureSkipVerify: true, // we expect a self-signed challenge certificate + } + + hostPort := net.JoinHostPort(tc.Value, "443") + + conn, err := vo.tlsDial("tcp", hostPort, config) + if err != nil { + return storeError(ctx, ch, db, ConnectionErr(errors.Wrapf(err, + "error doing TLS dial for %s", hostPort))) + } + defer conn.Close() + + cs := conn.ConnectionState() + certs := cs.PeerCertificates + + if len(certs) == 0 { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("%s "+ + "challenge for %s resulted in no certificates", tc.Type, tc.Value))) + } + + if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("cannot "+ + "negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))) + } + + leafCert := certs[0] + + if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ + "leaf certificate must contain a single DNS name, %v", tc.Value))) + } + + idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} + idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} + foundIDPeAcmeIdentifierV1Obsolete := false + + keyAuth, err := KeyAuthorization(tc.Token, jwk) + if err != nil { + return nil, err + } + hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) + + for _, ext := range leafCert.Extensions { + if idPeAcmeIdentifier.Equal(ext.Id) { + if !ext.Critical { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ + "certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical"))) + } + + var extValue []byte + rest, err := asn1.Unmarshal(ext.Value, &extValue) + + if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ + "certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value"))) + } + + if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))) + } + + ch.Status = StatusValid + ch.Error = nil + ch.Validated = clock.Now() + + return ServerInternalErr(db.UpdateChallenge(ctx, ch)) + } + + if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { + foundIDPeAcmeIdentifierV1Obsolete = true + } + } + + if foundIDPeAcmeIdentifierV1Obsolete { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ + "certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))) + } + + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ + "certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))) +} + +func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) error { + // Normalize domain for wildcard DNS names + // This is done to avoid making TXT lookups for domains like + // _acme-challenge.*.example.com + // Instead perform txt lookup for _acme-challenge.example.com + domain := strings.TrimPrefix(dc.Value, "*.") + + txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) + if err != nil { + return storeError(ctx, ch, db, DNSErr(errors.Wrapf(err, "error looking up TXT "+ + "records for domain %s", domain))) + } + + expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) + if err != nil { + return nil, err + } + h := sha256.Sum256([]byte(expectedKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + var found bool + for _, r := range txtRecords { + if r == expected { + found = true + break + } + } + if !found { + return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ + "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))) + } + + // Update and store the challenge. + ch.Status = StatusValid + ch.Error = nil + ch.Validated = time.Now().UTC() + + return ServerInternalErr(db.UpdateChallenge(ctx, ch)) +} + +// KeyAuthorization creates the ACME key authorization value from a token +// and a jwk. +func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { + thumbprint, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) + } + encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) + return fmt.Sprintf("%s.%s", token, encPrint), nil +} + +// storeError the given error to an ACME error and saves using the DB interface. +func (bc *baseChallenge) storeError(ctx context.Context, ch Challenge, db nosql.DB, err *Error) error { + ch.Error = err.ToACME() + if err := db.UpdateChallenge(ctx, ch); err != nil { + return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge")) + } + return nil +} + +type httpGetter func(string) (*http.Response, error) +type lookupTxt func(string) ([]string, error) +type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) + +type validateOptions struct { + httpGet httpGetter + lookupTxt lookupTxt + tlsDial tlsDialer +} diff --git a/acme/common.go b/acme/common.go index fec47b94..a5a1fe09 100644 --- a/acme/common.go +++ b/acme/common.go @@ -16,6 +16,7 @@ import ( // only those methods required by the ACME api/authority. type Provisioner interface { AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) + GetID() string GetName() string DefaultTLSCertDuration() time.Duration GetOptions() *provisioner.Options @@ -25,6 +26,7 @@ type Provisioner interface { type MockProvisioner struct { Mret1 interface{} Merr error + MgetID func() string MgetName func() string MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MdefaultTLSCertDuration func() time.Duration @@ -55,6 +57,7 @@ func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration { return m.Mret1.(time.Duration) } +// GetOptions mock func (m *MockProvisioner) GetOptions() *provisioner.Options { if m.MgetOptions != nil { return m.MgetOptions() @@ -62,6 +65,14 @@ func (m *MockProvisioner) GetOptions() *provisioner.Options { return m.Mret1.(*provisioner.Options) } +// GetID mock +func (m *MockProvisioner) GetID() string { + if m.MgetID != nil { + return m.MgetID() + } + return m.Mret1.(string) +} + // ContextKey is the key type for storing and searching for ACME request // essentials in the context of a request. type ContextKey string diff --git a/acme/db/db.go b/acme/db.go similarity index 100% rename from acme/db/db.go rename to acme/db.go diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index dd93c5d2..6e9ee8c0 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -25,45 +25,15 @@ func (dba *dbAccount) clone() *dbAccount { return &nu } -func (db *DB) saveDBAccount(nu *dbAccount, old *dbAccount) error { - var ( - err error - oldB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) - } - } - - b, err := json.Marshal(*nu) - if err != nil { - return errors.Wrap(err, "error marshaling new account object") - } - // Set the Account - _, swapped, err := db.CmpAndSwap(accountTable, []byte(nu.ID), oldB, b) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error storing account")) - case !swapped: - return ServerInternalErr(errors.New("error storing account; " + - "value has changed since last read")) - default: - return nil - } -} - // CreateAccount imlements the AcmeDB.CreateAccount interface. func (db *DB) CreateAccount(ctx context.Context, acc *types.Account) error { - id, err := randID() + acc.ID, err = randID() if err != nil { return nil, err } dba := &dbAccount{ - ID: id, + ID: acc.ID, Key: acc.Key, Contact: acc.Contact, Status: acc.Valid, @@ -84,7 +54,7 @@ func (db *DB) CreateAccount(ctx context.Context, acc *types.Account) error { case !swapped: return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) default: - if err = db.saveDBAccount(dba, nil); err != nil { + if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil { db.db.Del(accountByKeyIDTable, kidB) return err } @@ -115,9 +85,11 @@ func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*types.Account // UpdateAccount imlements the AcmeDB.UpdateAccount interface. func (db *DB) UpdateAccount(ctx context.Context, acc *types.Account) error { - kid := "from-context" + if len(acc.ID) == 0 { + return ServerInternalErr(errors.New("id cannot be empty")) + } - old, err := db.getDBAccountByKeyID(ctx, kid) + old, err := db.getDBAccount(ctx, acc.ID) if err != nil { return err } @@ -131,7 +103,7 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *types.Account) error { nu.Deactivated = clock.Now() } - return db.saveDBAccount(newdba, dba) + return db.save(ctx, old.ID, newdba, dba, "account", accountTable) } func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { @@ -145,12 +117,8 @@ func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, erro return string(id), nil } -// getDBAccountByKeyID retrieves Id associated with the given Kid. -func (db *DB) getDBAccountByKeyID(ctx context.Context, kid string) (*dbAccount, error) { - id, err := db.getAccountIDByKeyID(ctx, kid) - if err != nil { - return err - } +// getDBAccount retrieves and unmarshals dbAccount. +func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { data, err := db.db.Get(accountTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 32e42a69..a50d46f1 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -14,15 +14,15 @@ var defaultExpiryDuration = time.Hour * 24 // dbAuthz is the base authz type that others build from. type dbAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier Identifier `json:"identifier"` - Status string `json:"status"` - Expires time.Time `json:"expires"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` - Error *Error `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier *Identifier `json:"identifier"` + Status string `json:"status"` + Expires time.Time `json:"expires"` + Challenges []string `json:"challenges"` + Wildcard bool `json:"wildcard"` + Created time.Time `json:"created"` + Error *Error `json:"error"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -30,34 +30,6 @@ func (ba *dbAuthz) clone() *dbAuthz { return &u } -func (db *DB) saveDBAuthz(ctx context.Context, nu *authz, old *dbAuthz) error { - var ( - err error - oldB, newB []byte - ) - - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old authz")) - } - } - if newB, err = json.Marshal(nu); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new authz")) - } - _, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing authz")) - case !swapped: - return ServerInternalErr(errors.Errorf("error storing authz; " + - "value has changed since last read")) - default: - return nil - } -} - // getDBAuthz retrieves and unmarshals a database representation of the // ACME Authorization type. func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) { @@ -102,8 +74,11 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*types.Authoriza // CreateAuthorization creates an entry in the database for the Authorization. // Implements the acme.DB.CreateAuthorization interface. func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) error { - if len(authz.AccountID) == 0 { - return ServerInternalErr(errors.New("AccountID cannot be empty")) + if len(az.AccountID) == 0 { + return ServerInternalErr(errors.New("account-id cannot be empty")) + } + if az.Identifier == nil { + return ServerInternalErr(errors.New("identifier cannot be nil")) } az.ID, err = randID() if err != nil { @@ -113,7 +88,7 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) now := clock.Now() dbaz := &dbAuthz{ ID: az.ID, - AccountID: az.AccountId, + AccountID: az.AccountID, Status: types.StatusPending, Created: now, Expires: now.Add(defaultExpiryDuration), @@ -150,11 +125,14 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) } dbaz.Challenges = chIDs - return db.saveDBAuthz(ctx, dbaz, nil) + return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable) } // UpdateAuthorization saves an updated ACME Authorization to the database. func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization) error { + if len(az.ID) == 0 { + return ServerInternalErr(errors.New("id cannot be empty")) + } old, err := db.getDBAuthz(ctx, az.ID) if err != nil { return err @@ -164,53 +142,5 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization) nu.Status = az.Status nu.Error = az.Error - return db.saveDBAuthz(ctx, nu, old) -} - -/* -// updateStatus attempts to update the status on a dbAuthz and stores the -// updating object if necessary. -func (ba *dbAuthz) updateStatus(db nosql.DB) (authz, error) { - newAuthz := ba.clone() - - now := time.Now().UTC() - switch ba.Status { - case StatusInvalid: - return ba.parent(), nil - case StatusValid: - return ba.parent(), nil - case StatusPending: - // check expiry - if now.After(ba.Expires) { - newAuthz.Status = StatusInvalid - newAuthz.Error = MalformedErr(errors.New("authz has expired")) - break - } - - var isValid = false - for _, chID := range ba.Challenges { - ch, err := getChallenge(db, chID) - if err != nil { - return ba, err - } - if ch.getStatus() == StatusValid { - isValid = true - break - } - } - - if !isValid { - return ba.parent(), nil - } - newAuthz.Status = StatusValid - newAuthz.Error = nil - default: - return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) - } - - if err := newAuthz.save(db, ba); err != nil { - return ba, err - } - return newAuthz.parent(), nil + return db.save(ctx, old.ID, nu, old, "authz", authzTable) } -*/ diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index 6303f005..bd3be0d0 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -35,39 +35,6 @@ func (dbc *dbChallenge) clone() *dbChallenge { return &u } -// save writes the challenge to disk. For new challenges 'old' should be nil, -// otherwise 'old' should be a pointer to the acme challenge as it was at the -// start of the request. This method will fail if the value currently found -// in the bucket/row does not match the value of 'old'. -func (db *DB) saveDBChallenge(ctx context.Context, nu challenge, old challenge) error { - newB, err := json.Marshal(nu) - if err != nil { - return ServerInternalErr(errors.Wrap(err, - "error marshaling new acme challenge")) - } - var oldB []byte - if old == nil { - oldB = nil - } else { - oldB, err = json.Marshal(old) - if err != nil { - return ServerInternalErr(errors.Wrap(err, - "error marshaling old acme challenge")) - } - } - - _, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error saving acme challenge")) - case !swapped: - return ServerInternalErr(errors.New("error saving acme challenge; " + - "acme challenge has changed since last read")) - default: - return nil - } -} - func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { data, err := db.db.Get(challengeTable, []byte(id)) if nosql.IsErrNotFound(err) { @@ -121,7 +88,7 @@ func (db *DB) CreateChallenge(ctx context.context, ch *types.Challenge) error { Type: ch.Type, } - return dbch.saveDBChallenge(ctx, dbch, nil) + return dbch.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable) } // GetChallenge retrieves and unmarshals an ACME challenge type from the database. @@ -149,7 +116,10 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*types.Chal // UpdateChallenge updates an ACME challenge type in the database. func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error { - old, err := db.getDBChallenge(ctx, id) + if len(ch.ID) == 0 { + return ServerInternalErr(errors.New("id cannot be empty")) + } + old, err := db.getDBChallenge(ctx, ch.ID) if err != nil { return err } @@ -163,366 +133,5 @@ func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error { nu.Validated = clock.Now() } - return db.saveDBChallenge(ctx, nu, old) + return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) } - -//// http01Challenge represents an http-01 acme challenge. -//type http01Challenge struct { -// *baseChallenge -//} -// -//// newHTTP01Challenge returns a new acme http-01 challenge. -//func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { -// bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) -// if err != nil { -// return nil, err -// } -// bc.Type = "http-01" -// bc.Value = ops.Identifier.Value -// -// hc := &http01Challenge{bc} -// if err := hc.save(db, nil); err != nil { -// return nil, err -// } -// return hc, nil -//} -// -//// Validate attempts to validate the challenge. If the challenge has been -//// satisfactorily validated, the 'status' and 'validated' attributes are -//// updated. -//func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { -// // If already valid or invalid then return without performing validation. -// if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid { -// return hc, nil -// } -// url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token) -// -// resp, err := vo.httpGet(url) -// if err != nil { -// if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err, -// "error doing http GET for url %s", url))); err != nil { -// return nil, err -// } -// return hc, nil -// } -// if resp.StatusCode >= 400 { -// if err = hc.storeError(db, -// ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", -// url, resp.StatusCode))); err != nil { -// return nil, err -// } -// return hc, nil -// } -// defer resp.Body.Close() -// -// body, err := ioutil.ReadAll(resp.Body) -// if err != nil { -// return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+ -// "response body for url %s", url)) -// } -// keyAuth := strings.Trim(string(body), "\r\n") -// -// expected, err := KeyAuthorization(hc.Token, jwk) -// if err != nil { -// return nil, err -// } -// if keyAuth != expected { -// if err = hc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ -// "expected %s, but got %s", expected, keyAuth))); err != nil { -// return nil, err -// } -// return hc, nil -// } -// -// // Update and store the challenge. -// upd := &http01Challenge{hc.baseChallenge.clone()} -// upd.Status = StatusValid -// upd.Error = nil -// upd.Validated = clock.Now() -// -// if err := upd.save(db, hc); err != nil { -// return nil, err -// } -// return upd, nil -//} -// -//type tlsALPN01Challenge struct { -// *baseChallenge -//} -// -//// newTLSALPN01Challenge returns a new acme tls-alpn-01 challenge. -//func newTLSALPN01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { -// bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) -// if err != nil { -// return nil, err -// } -// bc.Type = "tls-alpn-01" -// bc.Value = ops.Identifier.Value -// -// hc := &tlsALPN01Challenge{bc} -// if err := hc.save(db, nil); err != nil { -// return nil, err -// } -// return hc, nil -//} -// -//func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { -// // If already valid or invalid then return without performing validation. -// if tc.getStatus() == StatusValid || tc.getStatus() == StatusInvalid { -// return tc, nil -// } -// -// config := &tls.Config{ -// NextProtos: []string{"acme-tls/1"}, -// ServerName: tc.Value, -// InsecureSkipVerify: true, // we expect a self-signed challenge certificate -// } -// -// hostPort := net.JoinHostPort(tc.Value, "443") -// -// conn, err := vo.tlsDial("tcp", hostPort, config) -// if err != nil { -// if err = tc.storeError(db, -// ConnectionErr(errors.Wrapf(err, "error doing TLS dial for %s", hostPort))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// defer conn.Close() -// -// cs := conn.ConnectionState() -// certs := cs.PeerCertificates -// -// if len(certs) == 0 { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("%s challenge for %s resulted in no certificates", -// tc.Type, tc.Value))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("cannot negotiate ALPN acme-tls/1 protocol for "+ -// "tls-alpn-01 challenge"))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// leafCert := certs[0] -// -// if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ -// "leaf certificate must contain a single DNS name, %v", tc.Value))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} -// idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} -// foundIDPeAcmeIdentifierV1Obsolete := false -// -// keyAuth, err := KeyAuthorization(tc.Token, jwk) -// if err != nil { -// return nil, err -// } -// hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) -// -// for _, ext := range leafCert.Extensions { -// if idPeAcmeIdentifier.Equal(ext.Id) { -// if !ext.Critical { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ -// "acmeValidationV1 extension not critical"))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// var extValue []byte -// rest, err := asn1.Unmarshal(ext.Value, &extValue) -// -// if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ -// "malformed acmeValidationV1 extension value"))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ -// "expected acmeValidationV1 extension value %s for this challenge but got %s", -// hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// upd := &tlsALPN01Challenge{tc.baseChallenge.clone()} -// upd.Status = StatusValid -// upd.Error = nil -// upd.Validated = clock.Now() -// -// if err := upd.save(db, tc); err != nil { -// return nil, err -// } -// return upd, nil -// } -// -// if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { -// foundIDPeAcmeIdentifierV1Obsolete = true -// } -// } -// -// if foundIDPeAcmeIdentifierV1Obsolete { -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ -// "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))); err != nil { -// return nil, err -// } -// return tc, nil -// } -// -// if err = tc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ -// "missing acmeValidationV1 extension"))); err != nil { -// return nil, err -// } -// return tc, nil -//} -// -//// dns01Challenge represents an dns-01 acme challenge. -//type dns01Challenge struct { -// *baseChallenge -//} -// -//// newDNS01Challenge returns a new acme dns-01 challenge. -//func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { -// bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) -// if err != nil { -// return nil, err -// } -// bc.Type = "dns-01" -// bc.Value = ops.Identifier.Value -// -// dc := &dns01Challenge{bc} -// if err := dc.save(db, nil); err != nil { -// return nil, err -// } -// return dc, nil -//} -// -//// KeyAuthorization creates the ACME key authorization value from a token -//// and a jwk. -//func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { -// thumbprint, err := jwk.Thumbprint(crypto.SHA256) -// if err != nil { -// return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) -// } -// encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) -// return fmt.Sprintf("%s.%s", token, encPrint), nil -//} -// -//// validate attempts to validate the challenge. If the challenge has been -//// satisfactorily validated, the 'status' and 'validated' attributes are -//// updated. -//func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { -// // If already valid or invalid then return without performing validation. -// if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid { -// return dc, nil -// } -// -// // Normalize domain for wildcard DNS names -// // This is done to avoid making TXT lookups for domains like -// // _acme-challenge.*.example.com -// // Instead perform txt lookup for _acme-challenge.example.com -// domain := strings.TrimPrefix(dc.Value, "*.") -// -// txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) -// if err != nil { -// if err = dc.storeError(db, -// DNSErr(errors.Wrapf(err, "error looking up TXT "+ -// "records for domain %s", domain))); err != nil { -// return nil, err -// } -// return dc, nil -// } -// -// expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) -// if err != nil { -// return nil, err -// } -// h := sha256.Sum256([]byte(expectedKeyAuth)) -// expected := base64.RawURLEncoding.EncodeToString(h[:]) -// var found bool -// for _, r := range txtRecords { -// if r == expected { -// found = true -// break -// } -// } -// if !found { -// if err = dc.storeError(db, -// RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ -// "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil { -// return nil, err -// } -// return dc, nil -// } -// -// // Update and store the challenge. -// upd := &dns01Challenge{dc.baseChallenge.clone()} -// upd.Status = StatusValid -// upd.Error = nil -// upd.Validated = time.Now().UTC() -// -// if err := upd.save(db, dc); err != nil { -// return nil, err -// } -// return upd, nil -//} -// -//// unmarshalChallenge unmarshals a challenge type into the correct sub-type. -//func unmarshalChallenge(data []byte) (challenge, error) { -// var getType struct { -// Type string `json:"type"` -// } -// if err := json.Unmarshal(data, &getType); err != nil { -// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type")) -// } -// -// switch getType.Type { -// case "dns-01": -// var bc baseChallenge -// if err := json.Unmarshal(data, &bc); err != nil { -// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ -// "challenge type into dns01Challenge")) -// } -// return &dns01Challenge{&bc}, nil -// case "http-01": -// var bc baseChallenge -// if err := json.Unmarshal(data, &bc); err != nil { -// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ -// "challenge type into http01Challenge")) -// } -// return &http01Challenge{&bc}, nil -// case "tls-alpn-01": -// var bc baseChallenge -// if err := json.Unmarshal(data, &bc); err != nil { -// return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ -// "challenge type into tlsALPN01Challenge")) -// } -// return &tlsALPN01Challenge{&bc}, nil -// default: -// return nil, ServerInternalErr(errors.Errorf("unexpected challenge type '%s'", getType.Type)) -// } -//} -// diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index 3459f212..f8f57f89 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -33,16 +33,10 @@ func (db *DB) CreateNonce() (Nonce, error) { if err != nil { return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) } - _, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b) - switch { - case err != nil: - return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce")) - case !swapped: - return nil, ServerInternalErr(errors.New("error storing nonce; " + - "value has changed since last read")) - default: - return Nonce(id), nil + if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil { + return "", err } + return Nonce(id), nil } // DeleteNonce verifies that the nonce is valid (by checking if it exists), diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index 8bfd1a66..e11b92b2 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -1,10 +1,56 @@ package nosql import ( + "context" + "encoding/json" + + "github.com/pkg/errors" nosqlDB "github.com/smallstep/nosql" ) +var ( + accountTable = []byte("acme_accounts") + accountByKeyIDTable = []byte("acme_keyID_accountID_index") + authzTable = []byte("acme_authzs") + challengeTable = []byte("acme_challenges") + nonceTable = []byte("nonces") + orderTable = []byte("acme_orders") + ordersByAccountIDTable = []byte("acme_account_orders_index") + certTable = []byte("acme_certs") +) + // DB is a struct that implements the AcmeDB interface. type DB struct { db nosqlDB.DB } + +// save writes the new data to the database, overwriting the old data if it +// existed. +func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { + newB, err := json.Marshal(nu) + if err != nil { + return ServerInternalErr(errors.Wrapf(err, + "error marshaling new acme %s", typ)) + } + var oldB []byte + if old == nil { + oldB = nil + } else { + oldB, err = json.Marshal(old) + if err != nil { + return ServerInternalErr(errors.Wrapf(err, + "error marshaling old acme %s", typ)) + } + } + + _, swapped, err := db.CmpAndSwap(table, []byte(id), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error saving acme %s", typ)) + case !swapped: + return ServerInternalErr(errors.Errorf("error saving acme %s; "+ + "changed since last read", typ)) + default: + return nil + } +} diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 5ba54790..a0ab60da 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -2,17 +2,13 @@ package nosql import ( "context" - "crypto/x509" "encoding/json" - "sort" - "strings" "sync" "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" - "go.step.sm/crypto/x509util" ) var defaultOrderExpiry = time.Hour * 24 @@ -20,19 +16,10 @@ var defaultOrderExpiry = time.Hour * 24 // Mutex for locking ordersByAccount index operations. var ordersByAccountMux sync.Mutex -// OrderOptions options with which to create a new Order. -type OrderOptions struct { - AccountID string `json:"accID"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore"` - NotAfter time.Time `json:"notAfter"` - backdate time.Duration - defaultDuration time.Duration -} - type dbOrder struct { ID string `json:"id"` AccountID string `json:"accountID"` + ProvisionerID string `json:"provisionerID"` Created time.Time `json:"created"` Expires time.Time `json:"expires,omitempty"` Status string `json:"status"` @@ -60,7 +47,7 @@ func (db *DB) getDBOrder(id string) (*dbOrder, error) { } // GetOrder retrieves an ACME Order from the database. -func (db *DB) GetOrder(id string) (*types.Order, error) { +func (db *DB) GetOrder(id string) (*acme.Order, error) { dbo, err := db.getDBOrder(id) azs := make([]string, len(dbo.Authorizations)) @@ -76,6 +63,7 @@ func (db *DB) GetOrder(id string) (*types.Order, error) { Authorizations: azs, Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), ID: dbo.ID, + ProvisionerID: dbo.ProvisionerID, } if dbo.Certificate != "" { @@ -84,60 +72,74 @@ func (db *DB) GetOrder(id string) (*types.Order, error) { return o, nil } -func (db *DB) CreateOrder(ctx context.Context, o *types.Order) error { - o.ID, err := randID() +// CreateOrder creates ACME Order resources and saves them to the DB. +func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { + if len(o.AccountID) == 0 { + return ServerInternalErr(errors.New("account-id cannot be empty")) + } + if len(o.ProvisionerID) == 0 { + return ServerInternalErr(errors.New("provisioner-id cannot be empty")) + } + if len(o.Identifiers) == 0 { + return ServerInternalErr(errors.New("identifiers cannot be empty")) + } + if o.DefaultDuration == 0 { + return ServerInternalErr(errors.New("default-duration cannot be empty")) + } + + o.ID, err = randID() if err != nil { return nil, err } - authzs := make([]string, len(ops.Identifiers)) + azIDs := make([]string, len(ops.Identifiers)) for i, identifier := range ops.Identifiers { - az, err := newAuthz(db, ops.AccountID, identifier) + az, err = db.CreateAuthorzation(&types.Authorization{ + AccountID: o.AccountID, + Identifier: o.Identifier, + }) if err != nil { - return nil, err + return err } - authzs[i] = az.getID() + azIDs[i] = az.ID } now := clock.Now() var backdate time.Duration - nbf := ops.NotBefore + nbf := o.NotBefore if nbf.IsZero() { nbf = now - backdate = -1 * ops.backdate + backdate = -1 * o.Backdate } - naf := ops.NotAfter + naf := o.NotAfter if naf.IsZero() { - naf = nbf.Add(ops.defaultDuration) + naf = nbf.Add(o.DefaultDuration) } dbo := &dbOrder{ - ID: id, - AccountID: ops.AccountID, + ID: o.ID, + AccountID: o.AccountID, + ProvisionerID: o.ProvisionerID, Created: now, Status: StatusPending, Expires: now.Add(defaultOrderExpiry), Identifiers: ops.Identifiers, NotBefore: nbf.Add(backdate), NotAfter: naf, - Authorizations: authzs, + Authorizations: azIDs, } - if err := db.saveDBOrder(dbo, nil); err != nil { + if err := db.save(ctx, o.ID, dbo, nil, orderTable); err != nil { return nil, err } var oidHelper = orderIDsByAccount{} - _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID) + _, err = oidHelper.addOrderID(db, o.AccountID, o.ID) if err != nil { return nil, err } return o, nil } -// newOrder returns a new Order type. -func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { -} - type orderIDsByAccount struct{} // addOrderID adds an order ID to a users index of in progress order IDs. @@ -188,7 +190,7 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri if err != nil { return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) } - if o, err = o.updateStatus(db); err != nil { + if o, err = o.UpdateStatus(db); err != nil { return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) } if o.Status == StatusPending { @@ -198,7 +200,7 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri // If the number of pending orders is less than the number of orders in the // list, then update the pending order list. if len(pendOids) != len(oids) { - if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + if err = orderIDs(pendOiUs).save(db, oids, accID); err != nil { return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ "len(orderIDs) = %d", len(pendOids))) } @@ -248,225 +250,3 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { return nil } } - -func (o *order) save(db nosql.DB, old *order) error { - var ( - err error - oldB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) - } - } - - newB, err := json.Marshal(o) - if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order")) - } - - _, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error storing order")) - case !swapped: - return ServerInternalErr(errors.New("error storing order; " + - "value has changed since last read")) - default: - return nil - } -} - -// updateStatus updates order status if necessary. -func (o *order) updateStatus(db nosql.DB) (*order, error) { - _newOrder := *o - newOrder := &_newOrder - - now := time.Now().UTC() - switch o.Status { - case StatusInvalid: - return o, nil - case StatusValid: - return o, nil - case StatusReady: - // check expiry - if now.After(o.Expires) { - newOrder.Status = StatusInvalid - newOrder.Error = MalformedErr(errors.New("order has expired")) - break - } - return o, nil - case StatusPending: - // check expiry - if now.After(o.Expires) { - newOrder.Status = StatusInvalid - newOrder.Error = MalformedErr(errors.New("order has expired")) - break - } - - var count = map[string]int{ - StatusValid: 0, - StatusInvalid: 0, - StatusPending: 0, - } - for _, azID := range o.Authorizations { - az, err := getAuthz(db, azID) - if err != nil { - return nil, err - } - if az, err = az.updateStatus(db); err != nil { - return nil, err - } - st := az.getStatus() - count[st]++ - } - switch { - case count[StatusInvalid] > 0: - newOrder.Status = StatusInvalid - - // No change in the order status, so just return the order as is - - // without writing any changes. - case count[StatusPending] > 0: - return newOrder, nil - - case count[StatusValid] == len(o.Authorizations): - newOrder.Status = StatusReady - - default: - return nil, ServerInternalErr(errors.New("unexpected authz status")) - } - default: - return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) - } - - if err := newOrder.save(db, o); err != nil { - return nil, err - } - return newOrder, nil -} - -// finalize signs a certificate if the necessary conditions for Order completion -// have been met. -func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) { - var err error - if o, err = o.updateStatus(db); err != nil { - return nil, err - } - - switch o.Status { - case StatusInvalid: - return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) - case StatusValid: - return o, nil - case StatusPending: - return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) - case StatusReady: - break - default: - return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) - } - - // RFC8555: The CSR MUST indicate the exact same set of requested - // identifiers as the initial newOrder request. Identifiers of type "dns" - // MUST appear either in the commonName portion of the requested subject - // name or in an extensionRequest attribute [RFC2985] requesting a - // subjectAltName extension, or both. - if csr.Subject.CommonName != "" { - csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) - } - csr.DNSNames = uniqueLowerNames(csr.DNSNames) - orderNames := make([]string, len(o.Identifiers)) - for i, n := range o.Identifiers { - orderNames[i] = n.Value - } - orderNames = uniqueLowerNames(orderNames) - - // Validate identifier names against CSR alternative names. - // - // Note that with certificate templates we are not going to check for the - // absence of other SANs as they will only be set if the templates allows - // them. - if len(csr.DNSNames) != len(orderNames) { - return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) - } - - sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) - for i := range csr.DNSNames { - if csr.DNSNames[i] != orderNames[i] { - return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) - } - sans[i] = x509util.SubjectAlternativeName{ - Type: x509util.DNSType, - Value: csr.DNSNames[i], - } - } - - // Get authorizations from the ACME provisioner. - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) - signOps, err := p.AuthorizeSign(ctx, "") - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) - } - - // Template data - data := x509util.NewTemplateData() - data.SetCommonName(csr.Subject.CommonName) - data.Set(x509util.SANsKey, sans) - - templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) - } - signOps = append(signOps, templateOptions) - - // Create and store a new certificate. - certChain, err := auth.Sign(csr, provisioner.SignOptions{ - NotBefore: provisioner.NewTimeDuration(o.NotBefore), - NotAfter: provisioner.NewTimeDuration(o.NotAfter), - }, signOps...) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) - } - - cert, err := newCert(db, CertOptions{ - AccountID: o.AccountID, - OrderID: o.ID, - Leaf: certChain[0], - Intermediates: certChain[1:], - }) - if err != nil { - return nil, err - } - - _newOrder := *o - newOrder := &_newOrder - newOrder.Certificate = cert.ID - newOrder.Status = StatusValid - if err := newOrder.save(db, o); err != nil { - return nil, err - } - return newOrder, nil -} - -// toACME converts the internal Order type into the public acmeOrder type for -// presentation in the ACME protocol. -func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) { -} - -// uniqueLowerNames returns the set of all unique names in the input after all -// of them are lowercased. The returned names will be in their lowercased form -// and sorted alphabetically. -func uniqueLowerNames(names []string) (unique []string) { - nameMap := make(map[string]int, len(names)) - for _, name := range names { - nameMap[strings.ToLower(name)] = 1 - } - unique = make([]string, 0, len(nameMap)) - for name := range nameMap { - unique = append(unique, name) - } - sort.Strings(unique) - return -} diff --git a/acme/errors.go b/acme/errors.go index a4dd8159..9bd9c400 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -186,6 +186,9 @@ func RejectedIdentifierErr(err error) *Error { // ServerInternalErr returns a new acme error. func ServerInternalErr(err error) *Error { + if err == nil { + return nil + } return &Error{ Type: serverInternalErr, Detail: "The server experienced an internal error", diff --git a/acme/types/nonce.go b/acme/nonce.go similarity index 100% rename from acme/types/nonce.go rename to acme/nonce.go diff --git a/acme/order.go b/acme/order.go index 574477ca..16a0ead2 100644 --- a/acme/order.go +++ b/acme/order.go @@ -1,4 +1,3 @@ -package acme import ( "context" @@ -6,32 +5,28 @@ import ( "encoding/json" "sort" "strings" - "sync" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/nosql" "go.step.sm/crypto/x509util" ) -var defaultOrderExpiry = time.Hour * 24 - -// Mutex for locking ordersByAccount index operations. -var ordersByAccountMux sync.Mutex - // Order contains order metadata for the ACME protocol order type. type Order struct { - Status string `json:"status"` - Expires string `json:"expires,omitempty"` - Identifiers []Identifier `json:"identifiers"` - NotBefore string `json:"notBefore,omitempty"` - NotAfter string `json:"notAfter,omitempty"` - Error interface{} `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - Finalize string `json:"finalize"` - Certificate string `json:"certificate,omitempty"` - ID string `json:"-"` + Status string `json:"status"` + Expires string `json:"expires,omitempty"` + Identifiers []Identifier `json:"identifiers"` + NotBefore string `json:"notBefore,omitempty"` + NotAfter string `json:"notAfter,omitempty"` + Error interface{} `json:"error,omitempty"` + Authorizations []string `json:"authorizations"` + Finalize string `json:"finalize"` + Certificate string `json:"certificate,omitempty"` + ID string `json:"-"` + ProvisionerID string `json:"-"` + DefaultDuration time.Duration `json:"-"` + Backdate time.Duration `json:"-"` } // ToLog enables response logging. @@ -43,251 +38,33 @@ func (o *Order) ToLog() (interface{}, error) { return string(b), nil } -// GetID returns the Order ID. -func (o *Order) GetID() string { - return o.ID -} - -// OrderOptions options with which to create a new Order. -type OrderOptions struct { - AccountID string `json:"accID"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore"` - NotAfter time.Time `json:"notAfter"` - backdate time.Duration - defaultDuration time.Duration -} - -type order struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires,omitempty"` - Status string `json:"status"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore,omitempty"` - NotAfter time.Time `json:"notAfter,omitempty"` - Error *Error `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - Certificate string `json:"certificate,omitempty"` -} - -// newOrder returns a new Order type. -func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { - id, err := randID() - if err != nil { - return nil, err - } - - authzs := make([]string, len(ops.Identifiers)) - for i, identifier := range ops.Identifiers { - az, err := newAuthz(db, ops.AccountID, identifier) - if err != nil { - return nil, err - } - authzs[i] = az.getID() - } - - now := clock.Now() - var backdate time.Duration - nbf := ops.NotBefore - if nbf.IsZero() { - nbf = now - backdate = -1 * ops.backdate - } - naf := ops.NotAfter - if naf.IsZero() { - naf = nbf.Add(ops.defaultDuration) - } - - o := &order{ - ID: id, - AccountID: ops.AccountID, - Created: now, - Status: StatusPending, - Expires: now.Add(defaultOrderExpiry), - Identifiers: ops.Identifiers, - NotBefore: nbf.Add(backdate), - NotAfter: naf, - Authorizations: authzs, - } - if err := o.save(db, nil); err != nil { - return nil, err - } - - var oidHelper = orderIDsByAccount{} - _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID) - if err != nil { - return nil, err - } - return o, nil -} - -type orderIDsByAccount struct{} - -// addOrderID adds an order ID to a users index of in progress order IDs. -// This method will also cull any orders that are no longer in the `pending` -// state from the index before returning it. -func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { - ordersByAccountMux.Lock() - defer ordersByAccountMux.Unlock() - - // Update the "order IDs by account ID" index - oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID) - if err != nil { - return nil, err - } - newOids := append(oids, oid) - if err = orderIDs(newOids).save(db, oids, accID); err != nil { - // Delete the entire order if storing the index fails. - db.Del(orderTable, []byte(oid)) - return nil, err - } - return newOids, nil -} - -// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the -// account. -func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { - b, err := db.Get(ordersByAccountIDTable, []byte(accID)) - if err != nil { - if nosql.IsErrNotFound(err) { - return []string{}, nil - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) - } - var oids []string - if err := json.Unmarshal(b, &oids); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) - } - - // Remove any order that is not in PENDING state and update the stored list - // before returning. - // - // According to RFC 8555: - // The server SHOULD include pending orders and SHOULD NOT include orders - // that are invalid in the array of URLs. - pendOids := []string{} - for _, oid := range oids { - o, err := getOrder(db, oid) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) - } - if o, err = o.updateStatus(db); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) - } - if o.Status == StatusPending { - pendOids = append(pendOids, oid) - } - } - // If the number of pending orders is less than the number of orders in the - // list, then update the pending order list. - if len(pendOids) != len(oids) { - if err = orderIDs(pendOids).save(db, oids, accID); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids))) - } - } - - return pendOids, nil -} - -type orderIDs []string - -// save is used to update the list of orderIDs keyed by ACME account ID -// stored in the database. -// -// This method always converts empty lists to 'nil' when storing to the DB. We -// do this to avoid any confusion between an empty list and a nil value in the -// db. -func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { - var ( - err error - oldb []byte - newb []byte - ) - if len(old) == 0 { - oldb = nil - } else { - oldb, err = json.Marshal(old) - if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) - } - } - if len(oids) == 0 { - newb = nil - } else { - newb, err = json.Marshal(oids) - if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) - } - } - _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) - switch { - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) - case !swapped: - return ServerInternalErr(errors.Errorf("error storing order IDs "+ - "for account %s; order IDs changed since last read", accID)) - default: - return nil - } -} - -func (o *order) save(db nosql.DB, old *order) error { - var ( - err error - oldB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) - } - } - - newB, err := json.Marshal(o) +// UpdateStatus updates the ACME Order Status if necessary. +// Changes to the order are saved using the database interface. +func (o *Order) UpdateStatus(ctx context.Context, db DB) error { + now := time.Now().UTC() + expiry, err := time.Parse(time.RFC3339, o.Expires) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order")) + return ServerInternalErr(errors.Wrap("error converting expiry string to time")) } - _, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error storing order")) - case !swapped: - return ServerInternalErr(errors.New("error storing order; " + - "value has changed since last read")) - default: - return nil - } -} - -// updateStatus updates order status if necessary. -func (o *order) updateStatus(db nosql.DB) (*order, error) { - _newOrder := *o - newOrder := &_newOrder - - now := time.Now().UTC() switch o.Status { case StatusInvalid: - return o, nil + return nil case StatusValid: - return o, nil + return nil case StatusReady: - // check expiry - if now.After(o.Expires) { - newOrder.Status = StatusInvalid - newOrder.Error = MalformedErr(errors.New("order has expired")) + // Check expiry + if now.After(expiry) { + o.Status = StatusInvalid + o.Error = MalformedErr(errors.New("order has expired")) break } - return o, nil + return nil case StatusPending: - // check expiry - if now.After(o.Expires) { - newOrder.Status = StatusInvalid - newOrder.Error = MalformedErr(errors.New("order has expired")) + // Check expiry + if now.After(expiry) { + o.Status = StatusInvalid + o.Error = MalformedErr(errors.New("order has expired")) break } @@ -297,27 +74,27 @@ func (o *order) updateStatus(db nosql.DB) (*order, error) { StatusPending: 0, } for _, azID := range o.Authorizations { - az, err := getAuthz(db, azID) + az, err := db.GetAuthorization(ctx, azID) if err != nil { - return nil, err + return false, err } - if az, err = az.updateStatus(db); err != nil { - return nil, err + if az, err = az.UpdateStatus(db); err != nil { + return false, err } - st := az.getStatus() + st := az.Status count[st]++ } switch { case count[StatusInvalid] > 0: - newOrder.Status = StatusInvalid + o.Status = StatusInvalid // No change in the order status, so just return the order as is - // without writing any changes. case count[StatusPending] > 0: - return newOrder, nil + return nil case count[StatusValid] == len(o.Authorizations): - newOrder.Status = StatusReady + o.Status = StatusReady default: return nil, ServerInternalErr(errors.New("unexpected authz status")) @@ -325,28 +102,24 @@ func (o *order) updateStatus(db nosql.DB) (*order, error) { default: return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) } - - if err := newOrder.save(db, o); err != nil { - return nil, err - } - return newOrder, nil + return db.UpdateOrder(ctx, o) } // finalize signs a certificate if the necessary conditions for Order completion // have been met. -func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) { +func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { var err error - if o, err = o.updateStatus(db); err != nil { + if o, err = o.UpdateStatus(db); err != nil { return nil, err } switch o.Status { case StatusInvalid: - return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) + return OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) case StatusValid: - return o, nil + return nil case StatusPending: - return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) + return OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) case StatusReady: break default: @@ -366,7 +139,7 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut for i, n := range o.Identifiers { orderNames[i] = n.Value } - orderNames = uniqueLowerNames(orderNames) + orderNames = uniqueSortedLowerNames(orderNames) // Validate identifier names against CSR alternative names. // @@ -425,59 +198,15 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut return nil, err } - _newOrder := *o - newOrder := &_newOrder - newOrder.Certificate = cert.ID - newOrder.Status = StatusValid - if err := newOrder.save(db, o); err != nil { - return nil, err - } - return newOrder, nil -} - -// getOrder retrieves and unmarshals an ACME Order type from the database. -func getOrder(db nosql.DB, id string) (*order, error) { - b, err := db.Get(orderTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id)) - } - var o order - if err := json.Unmarshal(b, &o); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order")) - } - return &o, nil -} - -// toACME converts the internal Order type into the public acmeOrder type for -// presentation in the ACME protocol. -func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) { - azs := make([]string, len(o.Authorizations)) - for i, aid := range o.Authorizations { - azs[i] = dir.getLink(ctx, AuthzLink, true, aid) - } - ao := &Order{ - Status: o.Status, - Expires: o.Expires.Format(time.RFC3339), - Identifiers: o.Identifiers, - NotBefore: o.NotBefore.Format(time.RFC3339), - NotAfter: o.NotAfter.Format(time.RFC3339), - Authorizations: azs, - Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), - ID: o.ID, - } - - if o.Certificate != "" { - ao.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate) - } - return ao, nil + o.Certificate = cert.ID + o.Status = StatusValid + return db.UpdateOrder(ctx, o) } -// uniqueLowerNames returns the set of all unique names in the input after all +// uniqueSortedLowerNames returns the set of all unique names in the input after all // of them are lowercased. The returned names will be in their lowercased form // and sorted alphabetically. -func uniqueLowerNames(names []string) (unique []string) { +func uniqueSortedLowerNames(names []string) (unique []string) { nameMap := make(map[string]int, len(names)) for _, name := range names { nameMap[strings.ToLower(name)] = 1 diff --git a/acme/types/status.go b/acme/status.go similarity index 100% rename from acme/types/status.go rename to acme/status.go diff --git a/acme/types/authz.go b/acme/types/authz.go deleted file mode 100644 index 4119f6c1..00000000 --- a/acme/types/authz.go +++ /dev/null @@ -1,27 +0,0 @@ -package types - -import ( - "encoding/json" - - "github.com/pkg/errors" -) - -// Authorization representst an ACME Authorization. -type Authorization struct { - Identifier Identifier `json:"identifier"` - Status string `json:"status"` - Expires string `json:"expires"` - Challenges []*Challenge `json:"challenges"` - Wildcard bool `json:"wildcard"` - ID string `json:"-"` - AccountID string `json:"-"` -} - -// ToLog enables response logging. -func (a *Authz) ToLog() (interface{}, error) { - b, err := json.Marshal(a) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) - } - return string(b), nil -} diff --git a/acme/types/challenge.go b/acme/types/challenge.go deleted file mode 100644 index 61bcd2fb..00000000 --- a/acme/types/challenge.go +++ /dev/null @@ -1,30 +0,0 @@ -package types - -import ( - "encoding/json" - - "github.com/pkg/errors" -) - -// Challenge represents an ACME response Challenge type. -type Challenge struct { - Type string `json:"type"` - Status string `json:"status"` - Token string `json:"token"` - Validated string `json:"validated,omitempty"` - URL string `json:"url"` - Error *AError `json:"error,omitempty"` - ID string `json:"-"` - AuthzID string `json:"-"` - AccountID string `json:"-"` - Value string `json:"-"` -} - -// ToLog enables response logging. -func (c *Challenge) ToLog() (interface{}, error) { - b, err := json.Marshal(c) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) - } - return string(b), nil -} diff --git a/acme/types/order.go b/acme/types/order.go deleted file mode 100644 index 7e472204..00000000 --- a/acme/types/order.go +++ /dev/null @@ -1,31 +0,0 @@ -package types - -import ( - "encoding/json" - - "github.com/pkg/errors" -) - -// Order contains order metadata for the ACME protocol order type. -type Order struct { - Status string `json:"status"` - Expires string `json:"expires,omitempty"` - Identifiers []Identifier `json:"identifiers"` - NotBefore string `json:"notBefore,omitempty"` - NotAfter string `json:"notAfter,omitempty"` - Error interface{} `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - Finalize string `json:"finalize"` - Certificate string `json:"certificate,omitempty"` - ID string `json:"-"` - ProvisionerID string `json:"-"` -} - -// ToLog enables response logging. -func (o *Order) ToLog() (interface{}, error) { - b, err := json.Marshal(o) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging")) - } - return string(b), nil -} From 121cc34cca32a1e7afe349abc7935b5ea62d9f00 Mon Sep 17 00:00:00 2001 From: max furman Date: Sun, 28 Feb 2021 10:09:06 -0800 Subject: [PATCH 06/47] [acme db interface] wip --- acme/account.go | 26 +------- acme/authority.go | 55 ++++++++-------- acme/authorization.go | 2 +- acme/certificate.go | 80 +++-------------------- acme/challenge.go | 2 +- acme/db.go | 34 +++++----- acme/db/nosql/account.go | 16 +++-- acme/db/nosql/certificate.go | 121 +++++++++++++++++++++++++++++++++++ acme/order.go | 3 +- acme/status.go | 2 +- 10 files changed, 190 insertions(+), 151 deletions(-) diff --git a/acme/account.go b/acme/account.go index ea40a646..a0f88d49 100644 --- a/acme/account.go +++ b/acme/account.go @@ -38,29 +38,5 @@ func (a *Account) GetKey() *jose.JSONWebKey { // IsValid returns true if the Account is valid. func (a *Account) IsValid() bool { - return a.Status == StatusValid + return Status(a.Status) == StatusValid } - -// AccountOptions are the options needed to create a new ACME account. -type AccountOptions struct { - Key *jose.JSONWebKey - Contact []string -} - -// AccountUpdateOptions are the options needed to update an existing ACME account. -type AccountUpdateOptions struct { - Contact []string - Status types.Status -} - -// toACME converts the internal Account type into the public acmeAccount -// type for presentation in the ACME protocol. -//func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) { -// return &Account{ -// Status: a.Status, -// Contact: a.Contact, -// Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), -// Key: a.Key, -// ID: a.ID, -// }, nil -//} diff --git a/acme/authority.go b/acme/authority.go index c0b6a732..d07f591a 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -28,16 +28,16 @@ type Interface interface { DeactivateAccount(ctx context.Context, accID string) (*Account, error) GetAccount(ctx context.Context, accID string) (*Account, error) GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error) - NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) - UpdateAccount(context.Context, string, []string) (*Account, error) + NewAccount(ctx context.Context, acc *Account) (*Account, error) + UpdateAccount(ctx context.Context, acc *Account) (*Account, error) - GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error) + GetAuthz(ctx context.Context, accID string, authzID string) (*Authorization, error) ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error) FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error) GetOrder(ctx context.Context, accID string, orderID string) (*Order, error) GetOrdersByAccount(ctx context.Context, accID string) ([]string, error) - NewOrder(ctx context.Context, oo OrderOptions) (*Order, error) + NewOrder(ctx context.Context, o *Order) (*Order, error) GetCertificate(string, string) ([]byte, error) @@ -140,22 +140,19 @@ func (a *Authority) UseNonce(ctx context.Context, nonce string) error { } // NewAccount creates, stores, and returns a new ACME account. -func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) { - a := NewAccount(ao) - if err := a.db.CreateAccount(ctx, a); err != nil { +func (a *Authority) NewAccount(ctx context.Context, acc *Account) (*Account, error) { + if err := a.db.CreateAccount(ctx, acc); err != nil { return ServerInternalErr(err) } return a, nil } // UpdateAccount updates an ACME account. -func (a *Authority) UpdateAccount(ctx context.Context, auo AccountUpdateOptions) (*Account, error) { - acc, err := a.db.GetAccount(ctx, auo.ID) - if err != nil { - return ServerInternalErr(err) - } - acc.Contact = auo.Contact - acc.Status = auo.Status +func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, error) { + /* + acc.Contact = auo.Contact + acc.Status = auo.Status + */ if err = a.db.UpdateAccount(ctx, acc); err != nil { return ServerInternalErr(err) } @@ -228,20 +225,19 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string } // NewOrder generates, stores, and returns a new ACME order. -func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) { +func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) { prov, err := ProvisionerFromContext(ctx) if err != nil { return nil, err } - return db.CreateOrder(ctx, &Order{ - AccountID: ops.AccountID, - ProvisionerID: prov.GetID(), - Backdate: a.backdate.Duration, - DefaultDuration: prov.DefaultTLSCertDuration(), - Identifiers: ops.Identifiers, - NotBefore: ops.NotBefore, - NotAfter: ops.NotAfter, - }) + o.DefaultDuration = prov.DefaultTLSCertDuration() + o.Backdate = a.backdate.Duration + o.ProvisionerID = prov.GetID() + + if err = db.CreateOrder(ctx, o); err != nil { + return nil, ServerInternalErr(err) + } + return o, nil } // FinalizeOrder attempts to finalize an order and generate a new certificate. @@ -271,7 +267,7 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs // GetAuthz retrieves and attempts to update the status on an ACME authz // before returning. -func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) { +func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authorization, error) { az, err := a.db.GetAuthorization(ctx, authzID) if err != nil { return nil, err @@ -316,13 +312,14 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j } // GetCertificate retrieves the Certificate by ID. -func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) { - cert, err := getCert(a.db, certID) +func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { + cert, err := a.db.GetCertificate(a.db, certID) if err != nil { return nil, err } - if accID != cert.AccountID { - return nil, UnauthorizedErr(errors.New("account does not own certificate")) + if cert.AccountID != accID { + log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) + return nil, UnauthorizedErr(errors.New("account does not own challenge")) } return cert.toACME(a.db, a.dir) } diff --git a/acme/authorization.go b/acme/authorization.go index a41950cd..f1ef0adc 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -1,4 +1,4 @@ -package types +package acme import ( "context" diff --git a/acme/certificate.go b/acme/certificate.go index 6a31c880..f088d93c 100644 --- a/acme/certificate.go +++ b/acme/certificate.go @@ -2,88 +2,28 @@ package acme import ( "crypto/x509" - "encoding/json" "encoding/pem" - "time" - "github.com/pkg/errors" "github.com/smallstep/nosql" ) -type certificate struct { - ID string `json:"id"` - Created time.Time `json:"created"` - AccountID string `json:"accountID"` - OrderID string `json:"orderID"` - Leaf []byte `json:"leaf"` - Intermediates []byte `json:"intermediates"` -} - -// CertOptions options with which to create and store a cert object. -type CertOptions struct { +// Certificate options with which to create and store a cert object. +type Certificate struct { + ID string AccountID string OrderID string Leaf *x509.Certificate Intermediates []*x509.Certificate } -func newCert(db nosql.DB, ops CertOptions) (*certificate, error) { - id, err := randID() - if err != nil { - return nil, err - } - - leaf := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: ops.Leaf.Raw, - }) - var intermediates []byte - for _, cert := range ops.Intermediates { - intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ +// ToACME encodes the entire X509 chain into a PEM list. +func (cert *Certificate) ToACME(db nosql.DB, dir *directory) ([]byte, error) { + var ret []byte + for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { + ret = append(ret, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", - Bytes: cert.Raw, + Bytes: c.Raw, })...) } - - cert := &certificate{ - ID: id, - AccountID: ops.AccountID, - OrderID: ops.OrderID, - Leaf: leaf, - Intermediates: intermediates, - Created: time.Now().UTC(), - } - certB, err := json.Marshal(cert) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate")) - } - - _, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB) - switch { - case err != nil: - return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate")) - case !swapped: - return nil, ServerInternalErr(errors.New("error storing certificate; " + - "value has changed since last read")) - default: - return cert, nil - } -} - -func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) { - return append(c.Leaf, c.Intermediates...), nil -} - -func getCert(db nosql.DB, id string) (*certificate, error) { - b, err := db.Get(certTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate")) - } - var cert certificate - if err := json.Unmarshal(b, &cert); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate")) - } - return &cert, nil + return ret, nil } diff --git a/acme/challenge.go b/acme/challenge.go index de178d6c..e7abaf64 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -1,4 +1,4 @@ -package types +package acme import ( "context" diff --git a/acme/db.go b/acme/db.go index 19449d76..846eed04 100644 --- a/acme/db.go +++ b/acme/db.go @@ -4,28 +4,28 @@ import "context" // DB is the DB interface expected by the step-ca ACME API. type DB interface { - CreateAccount(ctx context.Context, acc *types.Account) (*types.Account, error) - GetAccount(ctx context.Context, id string) (*types.Account, error) - GetAccountByKeyID(ctx context.Context, kid string) (*types.Account, error) - UpdateAccount(ctx context.Context, acc *types.Account) error + CreateAccount(ctx context.Context, acc *Account) (*Account, error) + GetAccount(ctx context.Context, id string) (*Account, error) + GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) + UpdateAccount(ctx context.Context, acc *Account) error - CreateNonce(ctx context.Context) (types.Nonce, error) - DeleteNonce(ctx context.Context, nonce types.Nonce) error + CreateNonce(ctx context.Context) (Nonce, error) + DeleteNonce(ctx context.Context, nonce Nonce) error - CreateAuthorization(ctx context.Context, authz *types.Authorization) error - GetAuthorization(ctx context.Context, id string) (*types.Authorization, error) - UpdateAuthorization(ctx context.Context, authz *types.Authorization) error + CreateAuthorization(ctx context.Context, az *Authorization) error + GetAuthorization(ctx context.Context, id string) (*Authorization, error) + UpdateAuthorization(ctx context.Context, az *Authorization) error - CreateCertificate(ctx context.Context, cert *types.Certificate) error - GetCertificate(ctx context.Context, id string) (*types.Certificate, error) + CreateCertificate(ctx context.Context, cert *Certificate) error + GetCertificate(ctx context.Context, id string) (*Certificate, error) - CreateChallenge(ctx context.Context, ch *types.Challenge) error - GetChallenge(ctx context.Context, id, authzID string) (*types.Challenge, error) - UpdateChallenge(ctx context.Context, ch *types.Challenge) error + CreateChallenge(ctx context.Context, ch *Challenge) error + GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error) + UpdateChallenge(ctx context.Context, ch *Challenge) error - CreateOrder(ctx context.Context, o *types.Order) error + CreateOrder(ctx context.Context, o *Order) error DeleteOrder(ctx context.Context, id string) error - GetOrder(ctx context.Context, id string) (*types.Order, error) + GetOrder(ctx context.Context, id string) (*Order, error) GetOrdersByAccountID(ctx context.Context, accountID string) error - UpdateOrder(ctx context.Context, o *types.Order) error + UpdateOrder(ctx context.Context, o *Order) error } diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 6e9ee8c0..e863c371 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -26,7 +26,7 @@ func (dba *dbAccount) clone() *dbAccount { } // CreateAccount imlements the AcmeDB.CreateAccount interface. -func (db *DB) CreateAccount(ctx context.Context, acc *types.Account) error { +func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { acc.ID, err = randID() if err != nil { return nil, err @@ -63,9 +63,13 @@ func (db *DB) CreateAccount(ctx context.Context, acc *types.Account) error { } // GetAccount retrieves an ACME account by ID. -func (db *DB) GetAccount(ctx context.Context, id string) (*types.Account, error) { +func (db *DB) GetAccount(ctx context.Context, id string) (*Account, error) { + acc, err := db.getDBAccount(ctx, id) + if err != nil { + return nil, err + } - return &types.Account{ + return &Account{ Status: dbacc.Status, Contact: dbacc.Contact, Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), @@ -75,7 +79,7 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*types.Account, error) } // GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). -func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*types.Account, error) { +func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { id, err := db.getAccountIDByKeyID(kid) if err != nil { return nil, err @@ -84,7 +88,7 @@ func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*types.Account } // UpdateAccount imlements the AcmeDB.UpdateAccount interface. -func (db *DB) UpdateAccount(ctx context.Context, acc *types.Account) error { +func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error { if len(acc.ID) == 0 { return ServerInternalErr(errors.New("id cannot be empty")) } @@ -99,7 +103,7 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *types.Account) error { nu.Status = acc.Status // If the status has changed to 'deactivated', then set deactivatedAt timestamp. - if acc.Status == types.StatusDeactivated && old.Status != types.Status.Deactivated { + if acc.Status == StatusDeactivated && old.Status != Status.Deactivated { nu.Deactivated = clock.Now() } diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go index e69de29b..a008db07 100644 --- a/acme/db/nosql/certificate.go +++ b/acme/db/nosql/certificate.go @@ -0,0 +1,121 @@ +package nosql + +import ( + "context" + "crypto/x509" + "encoding/json" + "encoding/pem" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql" +) + +type dbCert struct { + ID string `json:"id"` + Created time.Time `json:"created"` + AccountID string `json:"accountID"` + OrderID string `json:"orderID"` + Leaf []byte `json:"leaf"` + Intermediates []byte `json:"intermediates"` +} + +// CreateCertificate creates and stores an ACME certificate type. +func (db *DB) CreateCertificate(ctx context.Context, cert *Certificate) error { + cert.id, err = randID() + if err != nil { + return err + } + + leaf := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: ops.Leaf.Raw, + }) + var intermediates []byte + for _, cert := range ops.Intermediates { + intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })...) + } + + cert := &dbCert{ + ID: cert.ID, + AccountID: cert.AccountID, + OrderID: cert.OrderID, + Leaf: leaf, + Intermediates: intermediates, + Created: time.Now().UTC(), + } + return db.save(ctx, cert.ID, cert, nil, "certificate", certTable) +} + +// GetCertificate retrieves and unmarshals an ACME certificate type from the +// datastore. +func (db *DB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { + b, err := db.db.Get(certTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate")) + } + var dbCert certificate + if err := json.Unmarshal(b, &dbCert); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate")) + } + + leaf, err := parseCert(dbCert.Leaf) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf("error parsing leaf of ACME Certificate with ID '%s'", id)) + } + + intermediates, err := parseBundle(dbCert.Intermediates) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf("error parsing intermediate bundle of ACME Certificate with ID '%s'", id)) + } + + return &Certificate{ + ID: dbCert.ID, + AccountID: dbCert.AccountID, + OrderID: dbCert.OrderID, + Leaf: leaf, + Intermediates: intermediate, + } +} + +func parseCert(b []byte) (*x509.Certificate, error) { + block, rest := pem.Decode(dbCert.Leaf) + if block == nil || len(rest) > 0 { + return nil, errors.New("error decoding PEM block: contains unexpected data") + } + if block.Type != "CERTIFICATE" { + return nil, errors.New("error decoding PEM: block is not a certificate bundle") + } + var crt *x509.Certificate + crt, err = x509.ParseCertificate(block.Bytes) +} + +func parseBundle(b []byte) ([]*x509.Certificate, error) { + var block *pem.Block + var bundle []*x509.Certificate + for len(b) > 0 { + block, b = pem.Decode(b) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + return nil, errors.Errorf("error decoding PEM: file '%s' is not a certificate bundle", filename) + } + var crt *x509.Certificate + crt, err = x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrapf(err, "error parsing %s", filename) + } + bundle = append(bundle, crt) + } + if len(b) > 0 { + return nil, errors.Errorf("error decoding PEM: file '%s' contains unexpected data", filename) + } + return bundle, nil + +} diff --git a/acme/order.go b/acme/order.go index 16a0ead2..8879fed0 100644 --- a/acme/order.go +++ b/acme/order.go @@ -1,3 +1,4 @@ +package acme import ( "context" @@ -188,7 +189,7 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) } - cert, err := newCert(db, CertOptions{ + cert, err := db.CreateCertificate(ctx, &Certificate{ AccountID: o.AccountID, OrderID: o.ID, Leaf: certChain[0], diff --git a/acme/status.go b/acme/status.go index c98a506e..d9aae82d 100644 --- a/acme/status.go +++ b/acme/status.go @@ -1,4 +1,4 @@ -package types +package acme // Status represents an ACME status. type Status string From 2ae43ef2dc386ec7dc6a75d5bb65e07c1d4d3799 Mon Sep 17 00:00:00 2001 From: max furman Date: Sun, 28 Feb 2021 22:49:20 -0800 Subject: [PATCH 07/47] [acme db interface] wip errors --- acme/account.go | 16 +- acme/api/account.go | 7 +- acme/authority.go | 67 ++-- acme/authorization.go | 19 +- acme/certificate.go | 5 +- acme/challenge.go | 38 +- acme/common.go | 33 -- acme/db.go | 2 +- acme/db/nosql/account.go | 52 +-- acme/db/nosql/authz.go | 67 ++-- acme/db/nosql/certificate.go | 68 ++-- acme/db/nosql/challenge.go | 92 +++-- acme/db/nosql/nonce.go | 14 +- acme/db/nosql/nosql.go | 51 ++- acme/db/nosql/order.go | 36 +- acme/errors.go | 632 ++++++++++++-------------------- acme/order.go | 73 ++-- authority/provisioner/method.go | 3 +- 18 files changed, 562 insertions(+), 713 deletions(-) diff --git a/acme/account.go b/acme/account.go index a0f88d49..80cc66ef 100644 --- a/acme/account.go +++ b/acme/account.go @@ -1,9 +1,10 @@ package acme import ( + "crypto" + "encoding/base64" "encoding/json" - "github.com/pkg/errors" "go.step.sm/crypto/jose" ) @@ -11,7 +12,7 @@ import ( // attributes required for responses in the ACME protocol. type Account struct { Contact []string `json:"contact,omitempty"` - Status string `json:"status"` + Status Status `json:"status"` Orders string `json:"orders"` ID string `json:"-"` Key *jose.JSONWebKey `json:"-"` @@ -21,7 +22,7 @@ type Account struct { func (a *Account) ToLog() (interface{}, error) { b, err := json.Marshal(a) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging")) + return nil, ErrorWrap(ErrorServerInternalType, err, "error marshaling account for logging") } return string(b), nil } @@ -40,3 +41,12 @@ func (a *Account) GetKey() *jose.JSONWebKey { func (a *Account) IsValid() bool { return Status(a.Status) == StatusValid } + +// KeyToID converts a JWK to a thumbprint. +func KeyToID(jwk *jose.JSONWebKey) (string, error) { + kid, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return "", ErrorWrap(ErrorServerInternalType, err, "error generating jwk thumbprint") + } + return base64.RawURLEncoding.EncodeToString(kid), nil +} diff --git a/acme/api/account.go b/acme/api/account.go index 93f46651..ec2854cc 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -44,7 +44,7 @@ type UpdateAccountRequest struct { // IsDeactivateRequest returns true if the update request is a deactivation // request, false otherwise. func (u *UpdateAccountRequest) IsDeactivateRequest() bool { - return u.Status == acme.StatusDeactivated + return u.Status == string(acme.StatusDeactivated) } // Validate validates a update-account request body. @@ -59,7 +59,7 @@ func (u *UpdateAccountRequest) Validate() error { } return nil case len(u.Status) > 0: - if u.Status != acme.StatusDeactivated { + if u.Status != string(acme.StatusDeactivated) { return acme.MalformedErr(errors.Errorf("cannot update account "+ "status to %s, only deactivated", u.Status)) } @@ -110,9 +110,10 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } - if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{ + if acc, err = h.Auth.NewAccount(r.Context(), &acme.Account{ Key: jwk, Contact: nar.Contact, + Status: acme.StatusValid, }); err != nil { api.WriteError(w, err) return diff --git a/acme/authority.go b/acme/authority.go index d07f591a..77e031d0 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -2,10 +2,8 @@ package acme import ( "context" - "crypto" "crypto/tls" "crypto/x509" - "encoding/base64" "log" "net" "net/http" @@ -14,8 +12,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" - database "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" "go.step.sm/crypto/jose" ) @@ -49,7 +45,7 @@ type Interface interface { // Authority is the layer that handles all ACME interactions. type Authority struct { backdate provisioner.Duration - db nosql.DB + db DB dir *directory signAuth SignAuthority } @@ -57,8 +53,8 @@ type Authority struct { // AuthorityOptions required to create a new ACME Authority. type AuthorityOptions struct { Backdate provisioner.Duration - // DB is the database used by nosql. - DB nosql.DB + // DB storage backend that impements the acme.DB interface. + DB DB // DNS the host used to generate accurate ACME links. By default the authority // will use the Host from the request, so this value will only be used if // request.Host is empty. @@ -74,7 +70,7 @@ type AuthorityOptions struct { // // Deprecated: NewAuthority exists for hitorical compatibility and should not // be used. Use acme.New() instead. -func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { +func NewAuthority(db DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { return New(signAuth, AuthorityOptions{ DB: db, DNS: dns, @@ -84,19 +80,6 @@ func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Aut // New returns a new Authority that implements the ACME interface. func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { - if _, ok := ops.DB.(*database.SimpleDB); !ok { - // If it's not a SimpleDB then go ahead and bootstrap the DB with the - // necessary ACME tables. SimpleDB should ONLY be used for testing. - tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, - challengeTable, nonceTable, orderTable, ordersByAccountIDTable, - certTable} - for _, b := range tables { - if err := ops.DB.CreateTable(b); err != nil { - return nil, errors.Wrapf(err, "error creating table %s", - string(b)) - } - } - } return &Authority{ backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth, }, nil @@ -130,21 +113,21 @@ func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error } // NewNonce generates, stores, and returns a new ACME nonce. -func (a *Authority) NewNonce(ctx context.Context) (string, error) { +func (a *Authority) NewNonce(ctx context.Context) (Nonce, error) { return a.db.CreateNonce(ctx) } // UseNonce consumes the given nonce if it is valid, returns error otherwise. func (a *Authority) UseNonce(ctx context.Context, nonce string) error { - return a.db.DeleteNonce(ctx, nonce) + return a.db.DeleteNonce(ctx, Nonce(nonce)) } // NewAccount creates, stores, and returns a new ACME account. -func (a *Authority) NewAccount(ctx context.Context, acc *Account) (*Account, error) { +func (a *Authority) NewAccount(ctx context.Context, acc *Account) error { if err := a.db.CreateAccount(ctx, acc); err != nil { - return ServerInternalErr(err) + return ErrorWrap(ErrorServerInternalType, err, "newAccount: error creating account") } - return a, nil + return nil } // UpdateAccount updates an ACME account. @@ -153,8 +136,8 @@ func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, acc.Contact = auo.Contact acc.Status = auo.Status */ - if err = a.db.UpdateAccount(ctx, acc); err != nil { - return ServerInternalErr(err) + if err := a.db.UpdateAccount(ctx, acc); err != nil { + return nil, ErrorWrap(ErrorServerInternalType, err, "authority.UpdateAccount - database update failed" } return acc, nil } @@ -164,17 +147,9 @@ func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) return a.db.GetAccount(ctx, id) } -func keyToID(jwk *jose.JSONWebKey) (string, error) { - kid, err := jwk.Thumbprint(crypto.SHA256) - if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint")) - } - return base64.RawURLEncoding.EncodeToString(kid), nil -} - // GetAccountByKey returns the ACME associated with the jwk id. func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) { - kid, err := keyToID(jwk) + kid, err := KeyToID(jwk) if err != nil { return nil, err } @@ -200,12 +175,13 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) return nil, UnauthorizedErr(errors.New("provisioner does not own order")) } - if err = a.updateOrderStatus(ctx, o); err != nil { + if err = o.UpdateStatus(ctx, a.db); err != nil { return nil, err } - return o.toACME(ctx, a.db, a.dir) + return o, nil } +/* // GetOrdersByAccount returns the list of order urls owned by the account. func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { ordersByAccountMux.Lock() @@ -223,6 +199,7 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string } return ret, nil } +*/ // NewOrder generates, stores, and returns a new ACME order. func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) { @@ -234,7 +211,7 @@ func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) { o.Backdate = a.backdate.Duration o.ProvisionerID = prov.GetID() - if err = db.CreateOrder(ctx, o); err != nil { + if err = a.db.CreateOrder(ctx, o); err != nil { return nil, ServerInternalErr(err) } return o, nil @@ -258,8 +235,7 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) return nil, UnauthorizedErr(errors.New("provisioner does not own order")) } - o, err = o.Finalize(ctx, a.db, csr, a.signAuth, prov) - if err != nil { + if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil { return nil, Wrap(err, "error finalizing order") } return o, nil @@ -276,8 +252,7 @@ func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Autho log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) return nil, UnauthorizedErr(errors.New("account does not own authz")) } - az, err = az.UpdateStatus(ctx, a.db) - if err != nil { + if err = az.UpdateStatus(ctx, a.db); err != nil { return nil, Wrap(err, "error updating authz status") } return az, nil @@ -313,7 +288,7 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j // GetCertificate retrieves the Certificate by ID. func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { - cert, err := a.db.GetCertificate(a.db, certID) + cert, err := a.db.GetCertificate(ctx, certID) if err != nil { return nil, err } @@ -321,5 +296,5 @@ func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([ log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) return nil, UnauthorizedErr(errors.New("account does not own challenge")) } - return cert.toACME(a.db, a.dir) + return cert.ToACME(ctx) } diff --git a/acme/authorization.go b/acme/authorization.go index f1ef0adc..43095fb3 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -11,7 +11,7 @@ import ( // Authorization representst an ACME Authorization. type Authorization struct { Identifier *Identifier `json:"identifier"` - Status string `json:"status"` + Status Status `json:"status"` Expires string `json:"expires"` Challenges []*Challenge `json:"challenges"` Wildcard bool `json:"wildcard"` @@ -34,7 +34,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { now := time.Now().UTC() expiry, err := time.Parse(time.RFC3339, az.Expires) if err != nil { - return ServerInternalErr(errors.Wrap("error converting expiry string to time")) + return ServerInternalErr(errors.Wrap(err, "error converting expiry string to time")) } switch az.Status { @@ -46,16 +46,11 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { // check expiry if now.After(expiry) { az.Status = StatusInvalid - az.Error = MalformedErr(errors.New("authz has expired")) break } var isValid = false - for _, chID := range ba.Challenges { - ch, err := db.GetChallenge(ctx, chID, az.ID) - if err != nil { - return ServerInternalErr(err) - } + for _, ch := range az.Challenges { if ch.Status == StatusValid { isValid = true break @@ -66,10 +61,12 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { return nil } az.Status = StatusValid - az.Error = nil default: - return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) + return ServerInternalErr(errors.Errorf("unrecognized authorization status: %s", az.Status)) } - return ServerInternalErr(db.UpdateAuthorization(ctx, az)) + if err = db.UpdateAuthorization(ctx, az); err != nil { + return ServerInternalErr(err) + } + return nil } diff --git a/acme/certificate.go b/acme/certificate.go index f088d93c..356c0121 100644 --- a/acme/certificate.go +++ b/acme/certificate.go @@ -1,10 +1,9 @@ package acme import ( + "context" "crypto/x509" "encoding/pem" - - "github.com/smallstep/nosql" ) // Certificate options with which to create and store a cert object. @@ -17,7 +16,7 @@ type Certificate struct { } // ToACME encodes the entire X509 chain into a PEM list. -func (cert *Certificate) ToACME(db nosql.DB, dir *directory) ([]byte, error) { +func (cert *Certificate) ToACME(ctx context.Context) ([]byte, error) { var ret []byte for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { ret = append(ret, pem.EncodeToMemory(&pem.Block{ diff --git a/acme/challenge.go b/acme/challenge.go index e7abaf64..59ca454a 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -18,14 +18,13 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/nosql" "go.step.sm/crypto/jose" ) // Challenge represents an ACME response Challenge type. type Challenge struct { Type string `json:"type"` - Status string `json:"status"` + Status Status `json:"status"` Token string `json:"token"` Validated string `json:"validated,omitempty"` URL string `json:"url"` @@ -99,7 +98,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb // Update and store the challenge. ch.Status = StatusValid ch.Error = nil - ch.Validated = clock.Now() + ch.Validated = clock.Now().Format(time.RFC3339) return ServerInternalErr(db.UpdateChallenge(ctx, ch)) } @@ -107,11 +106,11 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, - ServerName: tc.Value, + ServerName: ch.Value, InsecureSkipVerify: true, // we expect a self-signed challenge certificate } - hostPort := net.JoinHostPort(tc.Value, "443") + hostPort := net.JoinHostPort(ch.Value, "443") conn, err := vo.tlsDial("tcp", hostPort, config) if err != nil { @@ -125,7 +124,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON if len(certs) == 0 { return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("%s "+ - "challenge for %s resulted in no certificates", tc.Type, tc.Value))) + "challenge for %s resulted in no certificates", ch.Type, ch.Value))) } if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { @@ -135,18 +134,18 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON leafCert := certs[0] - if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { + if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "leaf certificate must contain a single DNS name, %v", tc.Value))) + "leaf certificate must contain a single DNS name, %v", ch.Value))) } idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} foundIDPeAcmeIdentifierV1Obsolete := false - keyAuth, err := KeyAuthorization(tc.Token, jwk) + keyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { - return nil, err + return err } hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) @@ -173,9 +172,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON ch.Status = StatusValid ch.Error = nil - ch.Validated = clock.Now() + ch.Validated = clock.Now().Format(time.RFC3339) - return ServerInternalErr(db.UpdateChallenge(ctx, ch)) + if err = db.UpdateChallenge(ctx, ch); err != nil { + return ServerInternalErr(errors.Wrap(err, "tlsalpn01ValidateChallenge - error updating challenge")) + } + return nil } if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { @@ -192,12 +194,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))) } -func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com - domain := strings.TrimPrefix(dc.Value, "*.") + domain := strings.TrimPrefix(ch.Value, "*.") txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) if err != nil { @@ -205,9 +207,9 @@ func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JS "records for domain %s", domain))) } - expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) + expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { - return nil, err + return err } h := sha256.Sum256([]byte(expectedKeyAuth)) expected := base64.RawURLEncoding.EncodeToString(h[:]) @@ -226,7 +228,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JS // Update and store the challenge. ch.Status = StatusValid ch.Error = nil - ch.Validated = time.Now().UTC() + ch.Validated = clock.Now().UTC().Format(time.RFC3339) return ServerInternalErr(db.UpdateChallenge(ctx, ch)) } @@ -243,7 +245,7 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { } // storeError the given error to an ACME error and saves using the DB interface. -func (bc *baseChallenge) storeError(ctx context.Context, ch Challenge, db nosql.DB, err *Error) error { +func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { ch.Error = err.ToACME() if err := db.UpdateChallenge(ctx, ch); err != nil { return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge")) diff --git a/acme/common.go b/acme/common.go index a5a1fe09..1b268327 100644 --- a/acme/common.go +++ b/acme/common.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" - "go.step.sm/crypto/randutil" ) // Provisioner is an interface that implements a subset of the provisioner.Interface -- @@ -149,38 +148,6 @@ type SignAuthority interface { LoadProvisionerByID(string) (provisioner.Interface, error) } -// Identifier encodes the type that an order pertains to. -type Identifier struct { - Type string `json:"type"` - Value string `json:"value"` -} - -var ( - // StatusValid -- valid - StatusValid = "valid" - // StatusInvalid -- invalid - StatusInvalid = "invalid" - // StatusPending -- pending; e.g. an Order that is not ready to be finalized. - StatusPending = "pending" - // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid. - StatusDeactivated = "deactivated" - // StatusReady -- ready; e.g. for an Order that is ready to be finalized. - StatusReady = "ready" - //statusExpired = "expired" - //statusActive = "active" - //statusProcessing = "processing" -) - -var idLen = 32 - -func randID() (val string, err error) { - val, err = randutil.Alphanumeric(idLen) - if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID")) - } - return val, nil -} - // Clock that returns time in UTC rounded to seconds. type Clock int diff --git a/acme/db.go b/acme/db.go index 846eed04..dfbd30ce 100644 --- a/acme/db.go +++ b/acme/db.go @@ -4,7 +4,7 @@ import "context" // DB is the DB interface expected by the step-ca ACME API. type DB interface { - CreateAccount(ctx context.Context, acc *Account) (*Account, error) + CreateAccount(ctx context.Context, acc *Account) error GetAccount(ctx context.Context, id string) (*Account, error) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) UpdateAccount(ctx context.Context, acc *Account) error diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index e863c371..40961ce3 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -6,6 +6,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" nosqlDB "github.com/smallstep/nosql" "go.step.sm/crypto/jose" ) @@ -17,7 +18,7 @@ type dbAccount struct { Deactivated time.Time `json:"deactivated"` Key *jose.JSONWebKey `json:"key"` Contact []string `json:"contact,omitempty"` - Status string `json:"status"` + Status acme.Status `json:"status"` } func (dba *dbAccount) clone() *dbAccount { @@ -26,33 +27,34 @@ func (dba *dbAccount) clone() *dbAccount { } // CreateAccount imlements the AcmeDB.CreateAccount interface. -func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { +func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { + var err error acc.ID, err = randID() if err != nil { - return nil, err + return err } dba := &dbAccount{ ID: acc.ID, Key: acc.Key, Contact: acc.Contact, - Status: acc.Valid, + Status: acc.Status, Created: clock.Now(), } - kid, err := keyToID(dba.Key) + kid, err := acme.KeyToID(dba.Key) if err != nil { return err } kidB := []byte(kid) // Set the jwkID -> acme account ID index - _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID)) + _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID)) switch { case err != nil: - return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index")) + return errors.Wrap(err, "error storing keyID to accountID index") case !swapped: - return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) + return errors.Errorf("key-id to account-id index already exists") default: if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil { db.db.Del(accountByKeyIDTable, kidB) @@ -63,24 +65,24 @@ func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { } // GetAccount retrieves an ACME account by ID. -func (db *DB) GetAccount(ctx context.Context, id string) (*Account, error) { - acc, err := db.getDBAccount(ctx, id) +func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { + dbacc, err := db.getDBAccount(ctx, id) if err != nil { return nil, err } - return &Account{ + return &acme.Account{ Status: dbacc.Status, Contact: dbacc.Contact, - Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), + Orders: dir.getLink(ctx, OrdersByAccountLink, true, dbacc.ID), Key: dbacc.Key, ID: dbacc.ID, }, nil } // GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). -func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { - id, err := db.getAccountIDByKeyID(kid) +func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { + id, err := db.getAccountIDByKeyID(ctx, kid) if err != nil { return nil, err } @@ -88,9 +90,9 @@ func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, erro } // UpdateAccount imlements the AcmeDB.UpdateAccount interface. -func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error { +func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { if len(acc.ID) == 0 { - return ServerInternalErr(errors.New("id cannot be empty")) + return errors.New("id cannot be empty") } old, err := db.getDBAccount(ctx, acc.ID) @@ -99,24 +101,24 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error { } nu := old.clone() - nu.Contact = acc.contact + nu.Contact = acc.Contact nu.Status = acc.Status // If the status has changed to 'deactivated', then set deactivatedAt timestamp. - if acc.Status == StatusDeactivated && old.Status != Status.Deactivated { + if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated { nu.Deactivated = clock.Now() } - return db.save(ctx, old.ID, newdba, dba, "account", accountTable) + return db.save(ctx, old.ID, nu, old, "account", accountTable) } func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) + return "", errors.Wrapf(err, "account with key id %s not found", kid) } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) + return "", errors.Wrapf(err, "error loading key-account index") } return string(id), nil } @@ -126,14 +128,14 @@ func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { data, err := db.db.Get(accountTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) + return nil, errors.Wrapf(err, "account %s not found", id) } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) + return nil, errors.Wrapf(err, "error loading account %s", id) } - dbacc := new(account) + dbacc := new(dbAccount) if err = json.Unmarshal(data, dbacc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) + return nil, errors.Wrap(err, "error unmarshaling account") } return dbacc, nil } diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index a50d46f1..bc9f75bc 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" ) @@ -14,15 +15,15 @@ var defaultExpiryDuration = time.Hour * 24 // dbAuthz is the base authz type that others build from. type dbAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier *Identifier `json:"identifier"` - Status string `json:"status"` - Expires time.Time `json:"expires"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` - Error *Error `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier *acme.Identifier `json:"identifier"` + Status acme.Status `json:"status"` + Expires time.Time `json:"expires"` + Challenges []string `json:"challenges"` + Wildcard bool `json:"wildcard"` + Created time.Time `json:"created"` + Error *acme.Error `json:"error"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -35,33 +36,33 @@ func (ba *dbAuthz) clone() *dbAuthz { func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) { data, err := db.db.Get(authzTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id)) + return nil, errors.Wrapf(err, "authz %s not found", id) } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id)) + return nil, errors.Wrapf(err, "error loading authz %s", id) } var dbaz dbAuthz if err = json.Unmarshal(data, &dbaz); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dbAuthz")) + return nil, errors.Wrap(err, "error unmarshaling authz type into dbAuthz") } - return &dbaz + return &dbaz, nil } // GetAuthorization retrieves and unmarshals an ACME authz type from the database. // Implements acme.DB GetAuthorization interface. -func (db *DB) GetAuthorization(ctx context.Context, id string) (*types.Authorization, error) { - dbaz, err := getDBAuthz(id) +func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) { + dbaz, err := db.getDBAuthz(ctx, id) if err != nil { return nil, err } - var chs = make([]*Challenge, len(ba.Challenges)) + var chs = make([]*acme.Challenge, len(dbaz.Challenges)) for i, chID := range dbaz.Challenges { - chs[i], err = db.GetChallenge(ctx, chID) + chs[i], err = db.GetChallenge(ctx, chID, id) if err != nil { return nil, err } } - return &types.Authorization{ + return &acme.Authorization{ Identifier: dbaz.Identifier, Status: dbaz.Status, Challenges: chs, @@ -73,23 +74,24 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*types.Authoriza // CreateAuthorization creates an entry in the database for the Authorization. // Implements the acme.DB.CreateAuthorization interface. -func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) error { +func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error { if len(az.AccountID) == 0 { - return ServerInternalErr(errors.New("account-id cannot be empty")) + return errors.New("account-id cannot be empty") } if az.Identifier == nil { - return ServerInternalErr(errors.New("identifier cannot be nil")) + return errors.New("identifier cannot be nil") } + var err error az.ID, err = randID() if err != nil { - return nil, err + return err } now := clock.Now() dbaz := &dbAuthz{ ID: az.ID, AccountID: az.AccountID, - Status: types.StatusPending, + Status: acme.StatusPending, Created: now, Expires: now.Add(defaultExpiryDuration), Identifier: az.Identifier, @@ -97,9 +99,9 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) if strings.HasPrefix(az.Identifier.Value, "*.") { dbaz.Wildcard = true - dbaz.Identifier = Identifier{ - Value: strings.TrimPrefix(identifier.Value, "*."), - Type: identifier.Type, + dbaz.Identifier = &acme.Identifier{ + Value: strings.TrimPrefix(az.Identifier.Value, "*."), + Type: az.Identifier.Type, } } @@ -111,14 +113,14 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) } for _, typ := range chTypes { - ch, err := db.CreateChallenge(ctx, &types.Challenge{ + ch := &acme.Challenge{ AccountID: az.AccountID, AuthzID: az.ID, Value: az.Identifier.Value, Type: typ, - }) - if err != nil { - return nil, Wrapf(err, "error creating '%s' challenge", typ) + } + if err = db.CreateChallenge(ctx, ch); err != nil { + return errors.Wrapf(err, "error creating challenge") } chIDs = append(chIDs, ch.ID) @@ -129,9 +131,9 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) } // UpdateAuthorization saves an updated ACME Authorization to the database. -func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization) error { +func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error { if len(az.ID) == 0 { - return ServerInternalErr(errors.New("id cannot be empty")) + return errors.New("id cannot be empty") } old, err := db.getDBAuthz(ctx, az.ID) if err != nil { @@ -141,6 +143,5 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization) nu := old.clone() nu.Status = az.Status - nu.Error = az.Error return db.save(ctx, old.ID, nu, old, "authz", authzTable) } diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go index a008db07..ef766843 100644 --- a/acme/db/nosql/certificate.go +++ b/acme/db/nosql/certificate.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" ) @@ -21,25 +22,26 @@ type dbCert struct { } // CreateCertificate creates and stores an ACME certificate type. -func (db *DB) CreateCertificate(ctx context.Context, cert *Certificate) error { - cert.id, err = randID() +func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error { + var err error + cert.ID, err = randID() if err != nil { return err } leaf := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", - Bytes: ops.Leaf.Raw, + Bytes: cert.Leaf.Raw, }) var intermediates []byte - for _, cert := range ops.Intermediates { + for _, cert := range cert.Intermediates { intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, })...) } - cert := &dbCert{ + dbch := &dbCert{ ID: cert.ID, AccountID: cert.AccountID, OrderID: cert.OrderID, @@ -47,74 +49,80 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *Certificate) error { Intermediates: intermediates, Created: time.Now().UTC(), } - return db.save(ctx, cert.ID, cert, nil, "certificate", certTable) + return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) } // GetCertificate retrieves and unmarshals an ACME certificate type from the // datastore. -func (db *DB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { +func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) { b, err := db.db.Get(certTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id)) + return nil, errors.Wrapf(err, "certificate %s not found", id) } else if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate")) + return nil, errors.Wrap(err, "error loading certificate") } - var dbCert certificate - if err := json.Unmarshal(b, &dbCert); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate")) + dbC := new(dbCert) + if err := json.Unmarshal(b, dbC); err != nil { + return nil, errors.Wrap(err, "error unmarshaling certificate") } - leaf, err := parseCert(dbCert.Leaf) + leaf, err := parseCert(dbC.Leaf) if err != nil { - return nil, ServerInternalErr(errors.Wrapf("error parsing leaf of ACME Certificate with ID '%s'", id)) + return nil, errors.Wrapf(err, "error parsing leaf of ACME Certificate with ID '%s'", id) } - intermediates, err := parseBundle(dbCert.Intermediates) + intermediates, err := parseBundle(dbC.Intermediates) if err != nil { - return nil, ServerInternalErr(errors.Wrapf("error parsing intermediate bundle of ACME Certificate with ID '%s'", id)) + return nil, errors.Wrapf(err, "error parsing intermediate bundle of ACME Certificate with ID '%s'", id) } - return &Certificate{ - ID: dbCert.ID, - AccountID: dbCert.AccountID, - OrderID: dbCert.OrderID, + return &acme.Certificate{ + ID: dbC.ID, + AccountID: dbC.AccountID, + OrderID: dbC.OrderID, Leaf: leaf, - Intermediates: intermediate, - } + Intermediates: intermediates, + }, nil } func parseCert(b []byte) (*x509.Certificate, error) { - block, rest := pem.Decode(dbCert.Leaf) + block, rest := pem.Decode(b) if block == nil || len(rest) > 0 { return nil, errors.New("error decoding PEM block: contains unexpected data") } if block.Type != "CERTIFICATE" { return nil, errors.New("error decoding PEM: block is not a certificate bundle") } - var crt *x509.Certificate - crt, err = x509.ParseCertificate(block.Bytes) + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrap(err, "error parsing x509 certificate") + } + return cert, nil } func parseBundle(b []byte) ([]*x509.Certificate, error) { - var block *pem.Block - var bundle []*x509.Certificate + var ( + err error + block *pem.Block + bundle []*x509.Certificate + ) for len(b) > 0 { block, b = pem.Decode(b) if block == nil { break } if block.Type != "CERTIFICATE" { - return nil, errors.Errorf("error decoding PEM: file '%s' is not a certificate bundle", filename) + return nil, errors.New("error decoding PEM: data contains block that is not a certificate") } var crt *x509.Certificate crt, err = x509.ParseCertificate(block.Bytes) if err != nil { - return nil, errors.Wrapf(err, "error parsing %s", filename) + return nil, errors.Wrapf(err, "error parsing x509 certificate") } bundle = append(bundle, crt) } if len(b) > 0 { - return nil, errors.Errorf("error decoding PEM: file '%s' contains unexpected data", filename) + return nil, errors.New("error decoding PEM: unexpected data") } return bundle, nil diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index bd3be0d0..62513778 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -6,75 +6,69 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" ) -// ChallengeOptions is the type used to created a new Challenge. -type ChallengeOptions struct { - AccountID string - AuthzID string - Identifier Identifier -} - // dbChallenge is the base Challenge type that others build from. type dbChallenge struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - AuthzID string `json:"authzID"` - Type string `json:"type"` - Status string `json:"status"` - Token string `json:"token"` - Value string `json:"value"` - Validated time.Time `json:"validated"` - Created time.Time `json:"created"` - Error *AError `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + AuthzID string `json:"authzID"` + Type string `json:"type"` + Status acme.Status `json:"status"` + Token string `json:"token"` + Value string `json:"value"` + Validated string `json:"validated"` + Created time.Time `json:"created"` + Error *AError `json:"error"` } func (dbc *dbChallenge) clone() *dbChallenge { - u := *bc + u := *dbc return &u } func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { data, err := db.db.Get(challengeTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id)) + return nil, errors.Wrapf(err, "challenge %s not found", id) } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id)) + return nil, errors.Wrapf(err, "error loading challenge %s", id) } - dbch := new(baseChallenge) + dbch := new(dbChallenge) if err := json.Unmarshal(data, dbch); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into dbChallenge")) + return nil, errors.Wrap(err, "error unmarshaling dbChallenge") } - return dbch + return dbch, nil } // CreateChallenge creates a new ACME challenge data structure in the database. // Implements acme.DB.CreateChallenge interface. -func (db *DB) CreateChallenge(ctx context.context, ch *types.Challenge) error { +func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error { if len(ch.AuthzID) == 0 { - return ServerInternalError(errors.New("AuthzID cannot be empty")) + return errors.New("AuthzID cannot be empty") } if len(ch.AccountID) == 0 { - return ServerInternalError(errors.New("AccountID cannot be empty")) + return errors.New("AccountID cannot be empty") } if len(ch.Value) == 0 { - return ServerInternalError(errors.New("AccountID cannot be empty")) + return errors.New("AccountID cannot be empty") } // TODO: verify that challenge type is set and is one of expected types. if len(ch.Type) == 0 { - return ServerInternalError(errors.New("Type cannot be empty")) + return errors.New("Type cannot be empty") } + var err error ch.ID, err = randID() if err != nil { - return nil, Wrap(err, "error generating random id for ACME challenge") + return errors.Wrap(err, "error generating random id for ACME challenge") } ch.Token, err = randID() if err != nil { - return nil, Wrap(err, "error generating token for ACME challenge") + return errors.Wrap(err, "error generating token for ACME challenge") } dbch := &dbChallenge{ @@ -82,42 +76,40 @@ func (db *DB) CreateChallenge(ctx context.context, ch *types.Challenge) error { AuthzID: ch.AuthzID, AccountID: ch.AccountID, Value: ch.Value, - Status: types.StatusPending, + Status: acme.StatusPending, Token: ch.Token, Created: clock.Now(), Type: ch.Type, } - return dbch.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable) + return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable) } // GetChallenge retrieves and unmarshals an ACME challenge type from the database. // Implements the acme.DB GetChallenge interface. -func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*types.Challenge, error) { +func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) { dbch, err := db.getDBChallenge(ctx, id) if err != nil { - return err + return nil, err } - ch := &Challenge{ - Type: dbch.Type, - Status: dbch.Status, - Token: dbch.Token, - URL: dir.getLink(ctx, ChallengeLink, true, dbch.getID()), - ID: dbch.ID, - AuthzID: dbch.AuthzID(), - Error: dbch.Error, - } - if !dbch.Validated.IsZero() { - ac.Validated = dbch.Validated.Format(time.RFC3339) + ch := &acme.Challenge{ + Type: dbch.Type, + Status: dbch.Status, + Token: dbch.Token, + URL: dir.getLink(ctx, ChallengeLink, true, dbch.ID), + ID: dbch.ID, + AuthzID: dbch.AuthzID, + Error: dbch.Error, + Validated: dbch.Validated, } return ch, nil } // UpdateChallenge updates an ACME challenge type in the database. -func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error { +func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error { if len(ch.ID) == 0 { - return ServerInternalErr(errors.New("id cannot be empty")) + return errors.New("id cannot be empty") } old, err := db.getDBChallenge(ctx, ch.ID) if err != nil { @@ -129,9 +121,7 @@ func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error { // These should be the only values chaning in an Update request. nu.Status = ch.Status nu.Error = ch.Error - if nu.Status == types.StatusValid { - nu.Validated = clock.Now() - } + nu.Validated = ch.Validated return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) } diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index f8f57f89..02dcda6c 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -1,11 +1,13 @@ package nosql import ( + "context" "encoding/base64" "encoding/json" "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" nosqlDB "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" ) @@ -18,10 +20,10 @@ type dbNonce struct { // CreateNonce creates, stores, and returns an ACME replay-nonce. // Implements the acme.DB interface. -func (db *DB) CreateNonce() (Nonce, error) { +func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { _id, err := randID() if err != nil { - return nil, err + return "", err } id := base64.RawURLEncoding.EncodeToString([]byte(_id)) @@ -31,12 +33,12 @@ func (db *DB) CreateNonce() (Nonce, error) { } b, err := json.Marshal(n) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) + return "", errors.Wrap(err, "error marshaling nonce") } if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil { return "", err } - return Nonce(id), nil + return acme.Nonce(id), nil } // DeleteNonce verifies that the nonce is valid (by checking if it exists), @@ -59,9 +61,9 @@ func (db *DB) DeleteNonce(nonce string) error { switch { case nosqlDB.IsErrNotFound(err): - return BadNonceErr(nil) + return errors.New("not found") case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce)) + return errors.Wrapf(err, "error deleting nonce %s", nonce) default: return nil } diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index e11b92b2..0c040a89 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -3,9 +3,11 @@ package nosql import ( "context" "encoding/json" + "time" "github.com/pkg/errors" nosqlDB "github.com/smallstep/nosql" + "go.step.sm/crypto/randutil" ) var ( @@ -24,13 +26,26 @@ type DB struct { db nosqlDB.DB } +// New configures and returns a new ACME DB backend implemented using a nosql DB. +func New(db nosqlDB.DB) (*DB, error) { + tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, + challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable} + for _, b := range tables { + if err := db.CreateTable(b); err != nil { + return nil, errors.Wrapf(err, "error creating table %s", + string(b)) + } + } + return &DB{db}, nil +} + // save writes the new data to the database, overwriting the old data if it // existed. func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { newB, err := json.Marshal(nu) if err != nil { - return ServerInternalErr(errors.Wrapf(err, - "error marshaling new acme %s", typ)) + return errors.Wrapf(err, + "error marshaling new acme %s", typ) } var oldB []byte if old == nil { @@ -38,19 +53,39 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface } else { oldB, err = json.Marshal(old) if err != nil { - return ServerInternalErr(errors.Wrapf(err, - "error marshaling old acme %s", typ)) + return errors.Wrapf(err, + "error marshaling old acme %s", typ) } } - _, swapped, err := db.CmpAndSwap(table, []byte(id), oldB, newB) + _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) switch { case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error saving acme %s", typ)) + return errors.Wrapf(err, "error saving acme %s", typ) case !swapped: - return ServerInternalErr(errors.Errorf("error saving acme %s; "+ - "changed since last read", typ)) + return errors.Errorf("error saving acme %s; "+ + "changed since last read", typ) default: return nil } } + +var idLen = 32 + +func randID() (val string, err error) { + val, err = randutil.Alphanumeric(idLen) + if err != nil { + return "", errors.Wrap(err, "error generating random alphanumeric ID") + } + return val, nil +} + +// Clock that returns time in UTC rounded to seconds. +type Clock int + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Round(time.Second) +} + +var clock = new(Clock) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index a0ab60da..528619d4 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -17,51 +17,51 @@ var defaultOrderExpiry = time.Hour * 24 var ordersByAccountMux sync.Mutex type dbOrder struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - ProvisionerID string `json:"provisionerID"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires,omitempty"` - Status string `json:"status"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore,omitempty"` - NotAfter time.Time `json:"notAfter,omitempty"` - Error *Error `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - Certificate string `json:"certificate,omitempty"` + ID string `json:"id"` + AccountID string `json:"accountID"` + ProvisionerID string `json:"provisionerID"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires,omitempty"` + Status acme.Status `json:"status"` + Identifiers []acme.Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + Error *Error `json:"error,omitempty"` + Authorizations []string `json:"authorizations"` + Certificate string `json:"certificate,omitempty"` } // getDBOrder retrieves and unmarshals an ACME Order type from the database. func (db *DB) getDBOrder(id string) (*dbOrder, error) { b, err := db.db.Get(orderTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id)) + return nil, errors.Wrapf(err, "order %s not found", id) } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id)) + return nil, errors.Wrapf(err, "error loading order %s", id) } o := new(dbOrder) if err := json.Unmarshal(b, &o); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order")) + return nil, errors.Wrap(err, "error unmarshaling order") } return o, nil } // GetOrder retrieves an ACME Order from the database. -func (db *DB) GetOrder(id string) (*acme.Order, error) { +func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { dbo, err := db.getDBOrder(id) azs := make([]string, len(dbo.Authorizations)) for i, aid := range dbo.Authorizations { azs[i] = dir.getLink(ctx, AuthzLink, true, aid) } - o := &Order{ + o := &acme.Order{ Status: dbo.Status, Expires: dbo.Expires.Format(time.RFC3339), Identifiers: dbo.Identifiers, NotBefore: dbo.NotBefore.Format(time.RFC3339), NotAfter: dbo.NotAfter.Format(time.RFC3339), Authorizations: azs, - Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), + FinalizeURL: dir.getLink(ctx, FinalizeLink, true, o.ID), ID: dbo.ID, ProvisionerID: dbo.ProvisionerID, } diff --git a/acme/errors.go b/acme/errors.go index 9bd9c400..dc5b5568 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -1,410 +1,324 @@ +// Error represents an ACME package acme import ( + "fmt" + "github.com/pkg/errors" ) -// AccountDoesNotExistErr returns a new acme error. -func AccountDoesNotExistErr(err error) *Error { - return &Error{ - Type: accountDoesNotExistErr, - Detail: "Account does not exist", - Status: 400, - Err: err, - } -} - -// AlreadyRevokedErr returns a new acme error. -func AlreadyRevokedErr(err error) *Error { - return &Error{ - Type: alreadyRevokedErr, - Detail: "Certificate already revoked", - Status: 400, - Err: err, - } -} - -// BadCSRErr returns a new acme error. -func BadCSRErr(err error) *Error { - return &Error{ - Type: badCSRErr, - Detail: "The CSR is unacceptable", - Status: 400, - Err: err, - } -} - -// BadNonceErr returns a new acme error. -func BadNonceErr(err error) *Error { - return &Error{ - Type: badNonceErr, - Detail: "Unacceptable anti-replay nonce", - Status: 400, - Err: err, - } -} - -// BadPublicKeyErr returns a new acme error. -func BadPublicKeyErr(err error) *Error { - return &Error{ - Type: badPublicKeyErr, - Detail: "The jws was signed by a public key the server does not support", - Status: 400, - Err: err, - } -} - -// BadRevocationReasonErr returns a new acme error. -func BadRevocationReasonErr(err error) *Error { - return &Error{ - Type: badRevocationReasonErr, - Detail: "The revocation reason provided is not allowed by the server", - Status: 400, - Err: err, - } -} - -// BadSignatureAlgorithmErr returns a new acme error. -func BadSignatureAlgorithmErr(err error) *Error { - return &Error{ - Type: badSignatureAlgorithmErr, - Detail: "The JWS was signed with an algorithm the server does not support", - Status: 400, - Err: err, - } -} - -// CaaErr returns a new acme error. -func CaaErr(err error) *Error { - return &Error{ - Type: caaErr, - Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate", - Status: 400, - Err: err, - } -} - -// CompoundErr returns a new acme error. -func CompoundErr(err error) *Error { - return &Error{ - Type: compoundErr, - Detail: "Specific error conditions are indicated in the “subproblems” array", - Status: 400, - Err: err, - } -} - -// ConnectionErr returns a new acme error. -func ConnectionErr(err error) *Error { - return &Error{ - Type: connectionErr, - Detail: "The server could not connect to validation target", - Status: 400, - Err: err, - } -} - -// DNSErr returns a new acme error. -func DNSErr(err error) *Error { - return &Error{ - Type: dnsErr, - Detail: "There was a problem with a DNS query during identifier validation", - Status: 400, - Err: err, - } -} - -// ExternalAccountRequiredErr returns a new acme error. -func ExternalAccountRequiredErr(err error) *Error { - return &Error{ - Type: externalAccountRequiredErr, - Detail: "The request must include a value for the \"externalAccountBinding\" field", - Status: 400, - Err: err, - } -} - -// IncorrectResponseErr returns a new acme error. -func IncorrectResponseErr(err error) *Error { - return &Error{ - Type: incorrectResponseErr, - Detail: "Response received didn't match the challenge's requirements", - Status: 400, - Err: err, - } -} - -// InvalidContactErr returns a new acme error. -func InvalidContactErr(err error) *Error { - return &Error{ - Type: invalidContactErr, - Detail: "A contact URL for an account was invalid", - Status: 400, - Err: err, - } -} - -// MalformedErr returns a new acme error. -func MalformedErr(err error) *Error { - return &Error{ - Type: malformedErr, - Detail: "The request message was malformed", - Status: 400, - Err: err, - } -} - -// OrderNotReadyErr returns a new acme error. -func OrderNotReadyErr(err error) *Error { - return &Error{ - Type: orderNotReadyErr, - Detail: "The request attempted to finalize an order that is not ready to be finalized", - Status: 400, - Err: err, - } -} - -// RateLimitedErr returns a new acme error. -func RateLimitedErr(err error) *Error { - return &Error{ - Type: rateLimitedErr, - Detail: "The request exceeds a rate limit", - Status: 400, - Err: err, - } -} - -// RejectedIdentifierErr returns a new acme error. -func RejectedIdentifierErr(err error) *Error { - return &Error{ - Type: rejectedIdentifierErr, - Detail: "The server will not issue certificates for the identifier", - Status: 400, - Err: err, - } -} - -// ServerInternalErr returns a new acme error. -func ServerInternalErr(err error) *Error { - if err == nil { - return nil - } - return &Error{ - Type: serverInternalErr, - Detail: "The server experienced an internal error", - Status: 500, - Err: err, - } -} - -// NotImplemented returns a new acme error. -func NotImplemented(err error) *Error { - return &Error{ - Type: notImplemented, - Detail: "The requested operation is not implemented", - Status: 501, - Err: err, - } -} - -// TLSErr returns a new acme error. -func TLSErr(err error) *Error { - return &Error{ - Type: tlsErr, - Detail: "The server received a TLS error during validation", - Status: 400, - Err: err, - } -} - -// UnauthorizedErr returns a new acme error. -func UnauthorizedErr(err error) *Error { - return &Error{ - Type: unauthorizedErr, - Detail: "The client lacks sufficient authorization", - Status: 401, - Err: err, - } -} - -// UnsupportedContactErr returns a new acme error. -func UnsupportedContactErr(err error) *Error { - return &Error{ - Type: unsupportedContactErr, - Detail: "A contact URL for an account used an unsupported protocol scheme", - Status: 400, - Err: err, - } -} - -// UnsupportedIdentifierErr returns a new acme error. -func UnsupportedIdentifierErr(err error) *Error { - return &Error{ - Type: unsupportedIdentifierErr, - Detail: "An identifier is of an unsupported type", - Status: 400, - Err: err, - } -} - -// UserActionRequiredErr returns a new acme error. -func UserActionRequiredErr(err error) *Error { - return &Error{ - Type: userActionRequiredErr, - Detail: "Visit the “instance” URL and take actions specified there", - Status: 400, - Err: err, - } -} - -// ProbType is the type of the ACME problem. -type ProbType int +// ProblemType is the type of the ACME problem. +type ProblemType int const ( // The request specified an account that does not exist - accountDoesNotExistErr ProbType = iota + ErrorAccountDoesNotExistType ProblemType = iota // The request specified a certificate to be revoked that has already been revoked - alreadyRevokedErr + ErrorAlreadyRevokedType // The CSR is unacceptable (e.g., due to a short key) - badCSRErr + ErrorBadCSRType // The client sent an unacceptable anti-replay nonce - badNonceErr + ErrorBadNonceType // The JWS was signed by a public key the server does not support - badPublicKeyErr + ErrorBadPublicKeyType // The revocation reason provided is not allowed by the server - badRevocationReasonErr + ErrorBadRevocationReasonType // The JWS was signed with an algorithm the server does not support - badSignatureAlgorithmErr + ErrorBadSignatureAlgorithmType // Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate - caaErr + ErrorCaaType // Specific error conditions are indicated in the “subproblems” array. - compoundErr + ErrorCompoundType // The server could not connect to validation target - connectionErr + ErrorConnectionType // There was a problem with a DNS query during identifier validation - dnsErr + ErrorDNSType // The request must include a value for the “externalAccountBinding” field - externalAccountRequiredErr + ErrorExternalAccountRequiredType // Response received didn’t match the challenge’s requirements - incorrectResponseErr + ErrorIncorrectResponseType // A contact URL for an account was invalid - invalidContactErr + ErrorInvalidContactType // The request message was malformed - malformedErr + ErrorMalformedType // The request attempted to finalize an order that is not ready to be finalized - orderNotReadyErr + ErrorOrderNotReadyType // The request exceeds a rate limit - rateLimitedErr + ErrorRateLimitedType // The server will not issue certificates for the identifier - rejectedIdentifierErr + ErrorRejectedIdentifierType // The server experienced an internal error - serverInternalErr + ErrorServerInternalType // The server received a TLS error during validation - tlsErr + ErrorTLSType // The client lacks sufficient authorization - unauthorizedErr + ErrorUnauthorizedType // A contact URL for an account used an unsupported protocol scheme - unsupportedContactErr + ErrorUnsupportedContactType // An identifier is of an unsupported type - unsupportedIdentifierErr + ErrorUnsupportedIdentifierType // Visit the “instance” URL and take actions specified there - userActionRequiredErr + ErrorUserActionRequiredType // The operation is not implemented - notImplemented + ErrorNotImplementedType ) // String returns the string representation of the acme problem type, // fulfilling the Stringer interface. -func (ap ProbType) String() string { +func (ap ProblemType) String() string { switch ap { - case accountDoesNotExistErr: + case ErrorAccountDoesNotExistType: return "accountDoesNotExist" - case alreadyRevokedErr: + case ErrorAlreadyRevokedType: return "alreadyRevoked" - case badCSRErr: + case ErrorBadCSRType: return "badCSR" - case badNonceErr: + case ErrorBadNonceType: return "badNonce" - case badPublicKeyErr: + case ErrorBadPublicKeyType: return "badPublicKey" - case badRevocationReasonErr: + case ErrorBadRevocationReasonType: return "badRevocationReason" - case badSignatureAlgorithmErr: + case ErrorBadSignatureAlgorithmType: return "badSignatureAlgorithm" - case caaErr: + case ErrorCaaType: return "caa" - case compoundErr: + case ErrorCompoundType: return "compound" - case connectionErr: + case ErrorConnectionType: return "connection" - case dnsErr: + case ErrorDNSType: return "dns" - case externalAccountRequiredErr: + case ErrorExternalAccountRequiredType: return "externalAccountRequired" - case incorrectResponseErr: + case ErrorInvalidContactType: return "incorrectResponse" - case invalidContactErr: + case ErrorInvalidContactType: return "invalidContact" - case malformedErr: + case ErrorMalformedType: return "malformed" - case orderNotReadyErr: + case ErrorOrderNotReadyType: return "orderNotReady" - case rateLimitedErr: + case ErrorRateLimitedType: return "rateLimited" - case rejectedIdentifierErr: + case ErrorRejectedIdentifierType: return "rejectedIdentifier" - case serverInternalErr: + case ErrorServerInternalType: return "serverInternal" - case tlsErr: + case ErrorTLSType: return "tls" - case unauthorizedErr: + case ErrorUnauthorizedType: return "unauthorized" - case unsupportedContactErr: + case ErrorUnsupportedContactType: return "unsupportedContact" - case unsupportedIdentifierErr: + case ErrorUnsupportedIdentifierType: return "unsupportedIdentifier" - case userActionRequiredErr: + case ErrorUserActionRequiredType: return "userActionRequired" - case notImplemented: + case ErrorNotImplementedType: return "notImplemented" default: - return "unsupported type" + return fmt.Sprintf("unsupported type ACME error type %v", ap) + } +} + +type errorMetadata struct { + details string + status int + typ string + String string +} + +var ( + officialACMEPrefix = "urn:ietf:params:acme:error:" + stepACMEPrefix = "urn:step:acme:error:" + errorServerInternalMetadata = errorMetadata{ + ErrorAccountDoesNotExistType: { + typ: officialACMEPrefix + ErrorServerInternalType.String(), + details: "The server experienced an internal error", + status: 500, + }, + } + errorMap = [ProblemType]errorMetadata{ + ErrorAccountDoesNotExistType: { + typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(), + details: "Account does not exist", + status: 400, + }, + ErrorAlreadyRevokedType: { + typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(), + details: "Certificate already Revoked", + status: 400, + }, + ErrorBadCSRType: { + typ: officialACMEPrefix + ErrorBadCSRType.String(), + details: "The CSR is unacceptable", + status: 400, + }, + ErrorBadNonceType: { + typ: officialACMEPrefix + ErrorBadNonceType.String(), + details: "Unacceptable anti-replay nonce", + status: 400, + }, + ErrorBadPublicKeyType: { + typ: officialACMEPrefix + ErrorBadPublicKeyType.String(), + details: "The jws was signed by a public key the server does not support", + status: 400, + }, + ErrorBadRevocationReasonType: { + typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(), + details: "The revocation reason provided is not allowed by the server", + status: 400, + }, + ErrorBadSignatureAlgorithmType: { + typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(), + details: "The JWS was signed with an algorithm the server does not support", + status: 400, + }, + ErrorCaaType: { + typ: officialACMEPrefix + ErrorCaaType.String(), + details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate", + status: 400, + }, + ErrorCompoundType: { + typ: officialACMEPrefix + ErrorCompoundType.String(), + details: "Specific error conditions are indicated in the “subproblems” array", + status: 400, + }, + ErrorConnectionType: { + typ: officialACMEPrefix + ErrorConnectionType.String(), + details: "The server could not connect to validation target", + status: 400, + }, + ErrorDNSType: { + typ: officialACMEPrefix + ErrorDNSType.String(), + details: "There was a problem with a DNS query during identifier validation", + status: 400, + }, + ErrorExternalAccountRequiredType: { + typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(), + details: "The request must include a value for the \"externalAccountBinding\" field", + status: 400, + }, + ErrorIncorrectResponseType: { + typ: officialACMEPrefix + ErrorIncorrectResponseType.String(), + details: "Response received didn't match the challenge's requirements", + status: 400, + }, + ErrorInvalidContactType: { + typ: officialACMEPrefix + ErrorInvalidContactType.String(), + details: "A contact URL for an account was invalid", + status: 400, + }, + ErrorMalformedType: { + typ: officialACMEPrefix + ErrorMalformedType.String(), + details: "The request message was malformed", + status: 400, + }, + ErrorOrderNotReadyType: { + typ: officialACMEPrefix + ErrorOrderNotReadyType.String(), + details: "The request attempted to finalize an order that is not ready to be finalized", + status: 400, + }, + ErrorRateLimitedType: { + typ: officialACMEPrefix + ErrorRateLimitedType.String(), + details: "The request exceeds a rate limit", + status: 400, + }, + ErrorRejectedIdentifierType: { + typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(), + details: "The server will not issue certificates for the identifier", + status: 400, + }, + ErrorNotImplementedType: { + typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(), + details: "The requested operation is not implemented", + status: 501, + }, + ErrorTLSType: { + typ: officialACMEPrefix + ErrorTLSType.String(), + details: "The server received a TLS error during validation", + status: 400, + }, + ErrorUnauthorizedType: { + typ: officialACMEPrefix + ErrorUnauthorizedType.String(), + details: "The client lacks sufficient authorization", + status: 401, + }, + ErrorUnsupportedContactType: { + typ: officialACMEPrefix + ErrorUnsupportedContactType.String(), + details: "A contact URL for an account used an unsupported protocol scheme", + status: 400, + }, + ErrorUnsupportedIdentifierType: { + typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(), + details: "An identifier is of an unsupported type", + status: 400, + }, + ErrorUserActionRequiredType: { + typ: officialACMEPrefix + ErrorUserActionRequiredType.String(), + details: "Visit the “instance” URL and take actions specified there", + status: 400, + }, + ErrorServerInternalType: errorServerInternalMetadata, } -} +) -// Error is an ACME error type complete with problem document. +// Error represents an ACME type Error struct { - Type ProbType - Detail string - Err error - Status int - Sub []*Error - Identifier *Identifier + Type string `json:"type"` + Detail string `json:"detail"` + Subproblems []interface{} `json:"subproblems,omitempty"` + Identifier interface{} `json:"identifier,omitempty"` + Err error `json:"-"` + Status int `json:"-"` +} + +func NewError(pt ProblemType, msg string, args ...interface{}) *Error { + meta, ok := errorMetadata[typ] + if !ok { + meta = errorServerInternalMetadata + return &Error{ + Type: meta.typ, + Details: meta.details, + Status: meta.Status, + Err: errors.Errorf("unrecognized problemType %v", pt), + } + } + + return &Error{ + Type: meta.typ, + Details: meta.details, + Status: meta.status, + Err: errors.Errorf(msg, args...), + } } -// Wrap attempts to wrap the internal error. -func Wrap(err error, wrap string) *Error { +// ErrorWrap attempts to wrap the internal error. +func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Error { switch e := err.(type) { case nil: return nil case *Error: if e.Err == nil { - e.Err = errors.New(wrap + "; " + e.Detail) + e.Err = errors.Errorf(msg+"; "+e.Detail, args...) } else { - e.Err = errors.Wrap(e.Err, wrap) + e.Err = errors.Wrapf(e.Err, msg, args...) } return e default: - return ServerInternalErr(errors.Wrap(err, wrap)) + return NewError(ErrorServerInternalType, msg, args...) } } -// Error implements the error interface. +// StatusCode returns the status code and implements the StatusCoder interface. +func (e *Error) StatusCode() int { + return e.Status +} + +// Error allows AError to implement the error interface. func (e *Error) Error() string { - if e.Err == nil { - return e.Detail - } - return e.Err.Error() + return e.Detail } // Cause returns the internal error and implements the Causer interface. @@ -414,71 +328,3 @@ func (e *Error) Cause() error { } return e.Err } - -// Official returns true if this error's type is listed in §6.7 of RFC 8555. -// Error types in §6.7 are registered under IETF urn namespace: -// -// "urn:ietf:params:acme:error:" -// -// and should include the namespace as a prefix when appearing as a problem -// document. -// -// RFC 8555 also says: -// -// This list is not exhaustive. The server MAY return errors whose -// "type" field is set to a URI other than those defined above. Servers -// MUST NOT use the ACME URN namespace for errors not listed in the -// appropriate IANA registry (see Section 9.6). Clients SHOULD display -// the "detail" field of all errors. -// -// In this case Official returns `false` so that a different namespace can -// be used. -func (e *Error) Official() bool { - return e.Type != notImplemented -} - -// ToACME returns an acme representation of the problem type. -// For official errors, the IETF ACME namespace is prepended to the error type. -// For our own errors, we use an (yet) unregistered smallstep acme namespace. -func (e *Error) ToACME() *AError { - prefix := "urn:step:acme:error" - if e.Official() { - prefix = "urn:ietf:params:acme:error:" - } - ae := &AError{ - Type: prefix + e.Type.String(), - Detail: e.Error(), - Status: e.Status, - } - if e.Identifier != nil { - ae.Identifier = *e.Identifier - } - for _, p := range e.Sub { - ae.Subproblems = append(ae.Subproblems, p.ToACME()) - } - return ae -} - -// StatusCode returns the status code and implements the StatusCode interface. -func (e *Error) StatusCode() int { - return e.Status -} - -// AError is the error type as seen in acme request/responses. -type AError struct { - Type string `json:"type"` - Detail string `json:"detail"` - Identifier interface{} `json:"identifier,omitempty"` - Subproblems []interface{} `json:"subproblems,omitempty"` - Status int `json:"-"` -} - -// Error allows AError to implement the error interface. -func (ae *AError) Error() string { - return ae.Detail -} - -// StatusCode returns the status code and implements the StatusCode interface. -func (ae *AError) StatusCode() int { - return ae.Status -} diff --git a/acme/order.go b/acme/order.go index 8879fed0..01d3bc20 100644 --- a/acme/order.go +++ b/acme/order.go @@ -13,18 +13,25 @@ import ( "go.step.sm/crypto/x509util" ) +// Identifier encodes the type that an order pertains to. +type Identifier struct { + Type string `json:"type"` + Value string `json:"value"` +} + // Order contains order metadata for the ACME protocol order type. type Order struct { - Status string `json:"status"` + Status Status `json:"status"` Expires string `json:"expires,omitempty"` Identifiers []Identifier `json:"identifiers"` NotBefore string `json:"notBefore,omitempty"` NotAfter string `json:"notAfter,omitempty"` Error interface{} `json:"error,omitempty"` Authorizations []string `json:"authorizations"` - Finalize string `json:"finalize"` + FinalizeURL string `json:"finalize"` Certificate string `json:"certificate,omitempty"` ID string `json:"-"` + AccountID string `json:"-"` ProvisionerID string `json:"-"` DefaultDuration time.Duration `json:"-"` Backdate time.Duration `json:"-"` @@ -45,7 +52,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { now := time.Now().UTC() expiry, err := time.Parse(time.RFC3339, o.Expires) if err != nil { - return ServerInternalErr(errors.Wrap("error converting expiry string to time")) + return ServerInternalErr(errors.Wrap(err, "order.UpdateStatus - error converting expiry string to time")) } switch o.Status { @@ -69,7 +76,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { break } - var count = map[string]int{ + var count = map[Status]int{ StatusValid: 0, StatusInvalid: 0, StatusPending: 0, @@ -77,10 +84,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { for _, azID := range o.Authorizations { az, err := db.GetAuthorization(ctx, azID) if err != nil { - return false, err + return err } - if az, err = az.UpdateStatus(db); err != nil { - return false, err + if err = az.UpdateStatus(ctx, db); err != nil { + return err } st := az.Status count[st]++ @@ -98,20 +105,19 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { o.Status = StatusReady default: - return nil, ServerInternalErr(errors.New("unexpected authz status")) + return ServerInternalErr(errors.New("unexpected authz status")) } default: - return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) + return ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) } return db.UpdateOrder(ctx, o) } -// finalize signs a certificate if the necessary conditions for Order completion +// Finalize signs a certificate if the necessary conditions for Order completion // have been met. -func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { - var err error - if o, err = o.UpdateStatus(db); err != nil { - return nil, err +func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { + if err := o.UpdateStatus(ctx, db); err != nil { + return err } switch o.Status { @@ -124,7 +130,7 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth case StatusReady: break default: - return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) + return ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) } // RFC8555: The CSR MUST indicate the exact same set of requested @@ -135,7 +141,7 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth if csr.Subject.CommonName != "" { csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) } - csr.DNSNames = uniqueLowerNames(csr.DNSNames) + csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames) orderNames := make([]string, len(o.Identifiers)) for i, n := range o.Identifiers { orderNames[i] = n.Value @@ -148,13 +154,13 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth // absence of other SANs as they will only be set if the templates allows // them. if len(csr.DNSNames) != len(orderNames) { - return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + return BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) } sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) for i := range csr.DNSNames { if csr.DNSNames[i] != orderNames[i] { - return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + return BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) } sans[i] = x509util.SubjectAlternativeName{ Type: x509util.DNSType, @@ -163,10 +169,10 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth } // Get authorizations from the ACME provisioner. - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) + return ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) } // Template data @@ -176,27 +182,36 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) + return ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) } signOps = append(signOps, templateOptions) - // Create and store a new certificate. + nbf, err := time.Parse(time.RFC3339, o.NotBefore) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error parsing order NotBefore")) + } + naf, err := time.Parse(time.RFC3339, o.NotAfter) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error parsing order NotAfter")) + } + + // Sign a new certificate. certChain, err := auth.Sign(csr, provisioner.SignOptions{ - NotBefore: provisioner.NewTimeDuration(o.NotBefore), - NotAfter: provisioner.NewTimeDuration(o.NotAfter), + NotBefore: provisioner.NewTimeDuration(nbf), + NotAfter: provisioner.NewTimeDuration(naf), }, signOps...) if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) + return ServerInternalErr(errors.Wrapf(err, "error signing certificate for order %s", o.ID)) } - cert, err := db.CreateCertificate(ctx, &Certificate{ + cert := &Certificate{ AccountID: o.AccountID, OrderID: o.ID, Leaf: certChain[0], Intermediates: certChain[1:], - }) - if err != nil { - return nil, err + } + if err := db.CreateCertificate(ctx, cert); err != nil { + return err } o.Certificate = cert.ID diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index 775ed96f..f5cd5221 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -56,8 +56,7 @@ func NewContextWithMethod(ctx context.Context, method Method) context.Context { return context.WithValue(ctx, methodKey{}, method) } -// MethodFromContext returns the Method saved in ctx. Returns Sign if the given -// context has no Method associated with it. +// MethodFromContext returns the Method saved in ctx. func MethodFromContext(ctx context.Context) Method { m, _ := ctx.Value(methodKey{}).(Method) return m From 03ba229bcbf6acbb76fb2efd056041b5f6ac38d8 Mon Sep 17 00:00:00 2001 From: max furman Date: Sun, 28 Feb 2021 23:33:18 -0800 Subject: [PATCH 08/47] [acme db interface] wip more errors --- acme/authority.go | 45 +++++++++--------- acme/authorization.go | 10 ++-- acme/challenge.go | 106 ++++++++++++++++++++++-------------------- acme/common.go | 11 ++--- acme/directory.go | 4 +- acme/errors.go | 29 ++++++------ acme/order.go | 35 +++++++------- 7 files changed, 123 insertions(+), 117 deletions(-) diff --git a/acme/authority.go b/acme/authority.go index 77e031d0..098c48d4 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -10,7 +10,6 @@ import ( "net/url" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" ) @@ -125,7 +124,7 @@ func (a *Authority) UseNonce(ctx context.Context, nonce string) error { // NewAccount creates, stores, and returns a new ACME account. func (a *Authority) NewAccount(ctx context.Context, acc *Account) error { if err := a.db.CreateAccount(ctx, acc); err != nil { - return ErrorWrap(ErrorServerInternalType, err, "newAccount: error creating account") + return ErrorWrap(ErrorServerInternalType, err, "error creating account") } return nil } @@ -137,14 +136,18 @@ func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, acc.Status = auo.Status */ if err := a.db.UpdateAccount(ctx, acc); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "authority.UpdateAccount - database update failed" + return nil, ErrorWrap(ErrorServerInternalType, err, "error updating account") } return acc, nil } // GetAccount returns an ACME account. func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { - return a.db.GetAccount(ctx, id) + acc, err := a.db.GetAccount(ctx, id) + if err != nil { + return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving account") + } + return acc, nil } // GetAccountByKey returns the ACME associated with the jwk id. @@ -165,18 +168,18 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order } o, err := a.db.GetOrder(ctx, orderID) if err != nil { - return nil, ServerInternalErr(err) + return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving order") } if accID != o.AccountID { log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) - return nil, UnauthorizedErr(errors.New("account does not own order")) + return nil, NewError(ErrorUnauthorizedType, "account does not own order") } if prov.GetID() != o.ProvisionerID { log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) - return nil, UnauthorizedErr(errors.New("provisioner does not own order")) + return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") } if err = o.UpdateStatus(ctx, a.db); err != nil { - return nil, err + return nil, ErrorWrap(ErrorServerInternalType, err, "error updating order") } return o, nil } @@ -212,7 +215,7 @@ func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) { o.ProvisionerID = prov.GetID() if err = a.db.CreateOrder(ctx, o); err != nil { - return nil, ServerInternalErr(err) + return nil, ErrorWrap(ErrorServerInternalType, err, "error creating order") } return o, nil } @@ -225,18 +228,18 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs } o, err := a.db.GetOrder(ctx, orderID) if err != nil { - return nil, err + return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving order") } if accID != o.AccountID { log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) - return nil, UnauthorizedErr(errors.New("account does not own order")) + return nil, NewError(ErrorUnauthorizedType, "account does not own order") } if prov.GetID() != o.ProvisionerID { log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) - return nil, UnauthorizedErr(errors.New("provisioner does not own order")) + return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") } if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil { - return nil, Wrap(err, "error finalizing order") + return nil, ErrorWrap(ErrorServerInternalType, err, "error finalizing order") } return o, nil } @@ -246,14 +249,14 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authorization, error) { az, err := a.db.GetAuthorization(ctx, authzID) if err != nil { - return nil, err + return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving authorization") } if accID != az.AccountID { log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) - return nil, UnauthorizedErr(errors.New("account does not own authz")) + return nil, NewError(ErrorUnauthorizedType, "account does not own order") } if err = az.UpdateStatus(ctx, a.db); err != nil { - return nil, Wrap(err, "error updating authz status") + return nil, ErrorWrap(ErrorServerInternalType, err, "error updating authorization status") } return az, nil } @@ -262,11 +265,11 @@ func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Autho func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { ch, err := a.db.GetChallenge(ctx, chID, "todo") if err != nil { - return nil, err + return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving challenge") } if accID != ch.AccountID { log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, ch.AccountID) - return nil, UnauthorizedErr(errors.New("account does not own challenge")) + return nil, NewError(ErrorUnauthorizedType, "account does not own order") } client := http.Client{ Timeout: time.Duration(30 * time.Second), @@ -281,7 +284,7 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j return tls.DialWithDialer(dialer, network, addr, config) }, }); err != nil { - return nil, Wrap(err, "error attempting challenge validation") + return nil, ErrorWrap(ErrorServerInternalType, err, "error validating challenge") } return ch, nil } @@ -290,11 +293,11 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { cert, err := a.db.GetCertificate(ctx, certID) if err != nil { - return nil, err + return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving certificate") } if cert.AccountID != accID { log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) - return nil, UnauthorizedErr(errors.New("account does not own challenge")) + return nil, NewError(ErrorUnauthorizedType, "account does not own order") } return cert.ToACME(ctx) } diff --git a/acme/authorization.go b/acme/authorization.go index 43095fb3..ef230286 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "time" - - "github.com/pkg/errors" ) // Authorization representst an ACME Authorization. @@ -23,7 +21,7 @@ type Authorization struct { func (az *Authorization) ToLog() (interface{}, error) { b, err := json.Marshal(az) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) + return nil, ErrorInternalServerWrap(err, "error marshaling authz for logging") } return string(b), nil } @@ -34,7 +32,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { now := time.Now().UTC() expiry, err := time.Parse(time.RFC3339, az.Expires) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error converting expiry string to time")) + return ErrorInternalServerWrap(err, "error converting expiry string to time") } switch az.Status { @@ -62,11 +60,11 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { } az.Status = StatusValid default: - return ServerInternalErr(errors.Errorf("unrecognized authorization status: %s", az.Status)) + return NewError(ErrorServerInternalType, "unrecognized authorization status: %s", az.Status) } if err = db.UpdateAuthorization(ctx, az); err != nil { - return ServerInternalErr(err) + return ErrorInternalServerWrap(err, "error updating authorization") } return nil } diff --git a/acme/challenge.go b/acme/challenge.go index 59ca454a..05987427 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -17,29 +17,28 @@ import ( "strings" "time" - "github.com/pkg/errors" "go.step.sm/crypto/jose" ) // Challenge represents an ACME response Challenge type. type Challenge struct { - Type string `json:"type"` - Status Status `json:"status"` - Token string `json:"token"` - Validated string `json:"validated,omitempty"` - URL string `json:"url"` - Error *AError `json:"error,omitempty"` - ID string `json:"-"` - AuthzID string `json:"-"` - AccountID string `json:"-"` - Value string `json:"-"` + Type string `json:"type"` + Status Status `json:"status"` + Token string `json:"token"` + Validated string `json:"validated,omitempty"` + URL string `json:"url"` + Error *Error `json:"error,omitempty"` + ID string `json:"-"` + AuthzID string `json:"-"` + AccountID string `json:"-"` + Value string `json:"-"` } // ToLog enables response logging. func (ch *Challenge) ToLog() (interface{}, error) { b, err := json.Marshal(ch) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) + return nil, ErrorInternalServerWrap(err, "error marshaling challenge for logging") } return string(b), nil } @@ -61,7 +60,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, case "tls-alpn-01": return tlsalpn01Validate(ctx, ch, db, jwk, vo) default: - return ServerInternalErr(errors.Errorf("unexpected challenge type '%s'", ch.Type)) + return NewError(ErrorServerInternalType, "unexpected challenge type '%s'", ch.Type) } } @@ -70,19 +69,19 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb resp, err := vo.httpGet(url) if err != nil { - return storeError(ctx, ch, db, ConnectionErr(errors.Wrapf(err, - "error doing http GET for url %s", url))) + return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, + "error doing http GET for url %s", url)) } if resp.StatusCode >= 400 { - return storeError(ctx, ch, db, ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", - url, resp.StatusCode))) + return storeError(ctx, ch, db, NewError(ErrorConnectionType, + "error doing http GET for url %s with status code %d", url, resp.StatusCode)) } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return ServerInternalErr(errors.Wrapf(err, "error reading "+ - "response body for url %s", url)) + return ErrorInternalServerWrap(err, "error reading "+ + "response body for url %s", url) } keyAuth := strings.Trim(string(body), "\r\n") @@ -91,8 +90,8 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return err } if keyAuth != expected { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got %s", expected, keyAuth))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got %s", expected, keyAuth)) } // Update and store the challenge. @@ -100,7 +99,10 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb ch.Error = nil ch.Validated = clock.Now().Format(time.RFC3339) - return ServerInternalErr(db.UpdateChallenge(ctx, ch)) + if err = db.UpdateChallenge(ctx, ch); err != nil { + return ErrorInternalServerWrap(err, "error updating challenge") + } + return nil } func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { @@ -114,8 +116,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON conn, err := vo.tlsDial("tcp", hostPort, config) if err != nil { - return storeError(ctx, ch, db, ConnectionErr(errors.Wrapf(err, - "error doing TLS dial for %s", hostPort))) + return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, + "error doing TLS dial for %s", hostPort)) } defer conn.Close() @@ -123,20 +125,20 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON certs := cs.PeerCertificates if len(certs) == 0 { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("%s "+ - "challenge for %s resulted in no certificates", ch.Type, ch.Value))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "%s challenge for %s resulted in no certificates", ch.Type, ch.Value)) } if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("cannot "+ - "negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) } leafCert := certs[0] if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "leaf certificate must contain a single DNS name, %v", ch.Value))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)) } idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} @@ -152,22 +154,23 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON for _, ext := range leafCert.Extensions { if idPeAcmeIdentifier.Equal(ext.Id) { if !ext.Critical { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ - "certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical"))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) } var extValue []byte rest, err := asn1.Unmarshal(ext.Value, &extValue) if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ - "certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value"))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) } if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "expected acmeValidationV1 extension value %s for this challenge but got %s", - hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue))) } ch.Status = StatusValid @@ -175,7 +178,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ServerInternalErr(errors.Wrap(err, "tlsalpn01ValidateChallenge - error updating challenge")) + return ErrorInternalServerWrap(err, "tlsalpn01ValidateChallenge - error updating challenge") } return nil } @@ -186,12 +189,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON } if foundIDPeAcmeIdentifierV1Obsolete { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ - "certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) } - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect "+ - "certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { @@ -203,8 +206,8 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) if err != nil { - return storeError(ctx, ch, db, DNSErr(errors.Wrapf(err, "error looking up TXT "+ - "records for domain %s", domain))) + return storeError(ctx, ch, db, ErrorWrap(ErrorDNSType, err, + "error looking up TXT records for domain %s", domain)) } expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk) @@ -221,8 +224,8 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK } } if !found { - return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ - "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))) + return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords)) } // Update and store the challenge. @@ -230,7 +233,10 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK ch.Error = nil ch.Validated = clock.Now().UTC().Format(time.RFC3339) - return ServerInternalErr(db.UpdateChallenge(ctx, ch)) + if err = db.UpdateChallenge(ctx, ch); err != nil { + return ErrorInternalServerWrap(err, "error updating challenge") + } + return nil } // KeyAuthorization creates the ACME key authorization value from a token @@ -238,7 +244,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) + return "", ErrorInternalServerWrap(err, "error generating JWK thumbprint") } encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) return fmt.Sprintf("%s.%s", token, encPrint), nil @@ -246,9 +252,9 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { // storeError the given error to an ACME error and saves using the DB interface. func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { - ch.Error = err.ToACME() + ch.Error = err if err := db.UpdateChallenge(ctx, ch); err != nil { - return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge")) + return ErrorInternalServerWrap(err, "failure saving error to acme challenge") } return nil } diff --git a/acme/common.go b/acme/common.go index 1b268327..b9dc6ff2 100644 --- a/acme/common.go +++ b/acme/common.go @@ -6,7 +6,6 @@ import ( "net/url" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" ) @@ -96,7 +95,7 @@ const ( func AccountFromContext(ctx context.Context) (*Account, error) { val, ok := ctx.Value(AccContextKey).(*Account) if !ok || val == nil { - return nil, AccountDoesNotExistErr(nil) + return nil, NewError(ErrorServerInternalType, "account not in context") } return val, nil } @@ -114,7 +113,7 @@ func BaseURLFromContext(ctx context.Context) *url.URL { func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey) if !ok || val == nil { - return nil, ServerInternalErr(errors.Errorf("jwk expected in request context")) + return nil, NewError(ErrorServerInternalType, "jwk expected in request context") } return val, nil } @@ -123,7 +122,7 @@ func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature) if !ok || val == nil { - return nil, ServerInternalErr(errors.Errorf("jws expected in request context")) + return nil, NewError(ErrorServerInternalType, "jws expected in request context") } return val, nil } @@ -133,11 +132,11 @@ func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { val := ctx.Value(ProvisionerContextKey) if val == nil { - return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context")) + return nil, NewError(ErrorServerInternalType, "provisioner expected in request context") } pval, ok := val.(Provisioner) if !ok || pval == nil { - return nil, ServerInternalErr(errors.Errorf("provisioner in context is not an ACME provisioner")) + return nil, NewError(ErrorServerInternalType, "provisioner in context is not an ACME provisioner") } return pval, nil } diff --git a/acme/directory.go b/acme/directory.go index d5681b73..1b5b8c4b 100644 --- a/acme/directory.go +++ b/acme/directory.go @@ -5,8 +5,6 @@ import ( "encoding/json" "fmt" "net/url" - - "github.com/pkg/errors" ) // Directory represents an ACME directory for configuring clients. @@ -23,7 +21,7 @@ type Directory struct { func (d *Directory) ToLog() (interface{}, error) { b, err := json.Marshal(d) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging")) + return nil, ErrorInternalServerWrap(err, "error marshaling directory for logging") } return string(b), nil } diff --git a/acme/errors.go b/acme/errors.go index dc5b5568..aabc7302 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -93,8 +93,6 @@ func (ap ProblemType) String() string { return "externalAccountRequired" case ErrorInvalidContactType: return "incorrectResponse" - case ErrorInvalidContactType: - return "invalidContact" case ErrorMalformedType: return "malformed" case ErrorOrderNotReadyType: @@ -133,13 +131,11 @@ var ( officialACMEPrefix = "urn:ietf:params:acme:error:" stepACMEPrefix = "urn:step:acme:error:" errorServerInternalMetadata = errorMetadata{ - ErrorAccountDoesNotExistType: { - typ: officialACMEPrefix + ErrorServerInternalType.String(), - details: "The server experienced an internal error", - status: 500, - }, + typ: officialACMEPrefix + ErrorServerInternalType.String(), + details: "The server experienced an internal error", + status: 500, } - errorMap = [ProblemType]errorMetadata{ + errorMap = map[ProblemType]errorMetadata{ ErrorAccountDoesNotExistType: { typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(), details: "Account does not exist", @@ -267,7 +263,7 @@ var ( // Error represents an ACME type Error struct { Type string `json:"type"` - Detail string `json:"detail"` + Details string `json:"detail"` Subproblems []interface{} `json:"subproblems,omitempty"` Identifier interface{} `json:"identifier,omitempty"` Err error `json:"-"` @@ -275,13 +271,13 @@ type Error struct { } func NewError(pt ProblemType, msg string, args ...interface{}) *Error { - meta, ok := errorMetadata[typ] + meta, ok := errorMap[pt] if !ok { meta = errorServerInternalMetadata return &Error{ Type: meta.typ, Details: meta.details, - Status: meta.Status, + Status: meta.status, Err: errors.Errorf("unrecognized problemType %v", pt), } } @@ -301,7 +297,7 @@ func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Err return nil case *Error: if e.Err == nil { - e.Err = errors.Errorf(msg+"; "+e.Detail, args...) + e.Err = errors.Errorf(msg+"; "+e.Details, args...) } else { e.Err = errors.Wrapf(e.Err, msg, args...) } @@ -311,6 +307,11 @@ func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Err } } +// ErrorInternalServerWrap shortcut to wrap an internal server error type. +func ErrorInternalServerWrap(err error, msg string, args ...interface{}) *Error { + return ErrorWrap(ErrorServerInternalType, err, msg, args...) +} + // StatusCode returns the status code and implements the StatusCoder interface. func (e *Error) StatusCode() int { return e.Status @@ -318,13 +319,13 @@ func (e *Error) StatusCode() int { // Error allows AError to implement the error interface. func (e *Error) Error() string { - return e.Detail + return e.Details } // Cause returns the internal error and implements the Causer interface. func (e *Error) Cause() error { if e.Err == nil { - return errors.New(e.Detail) + return errors.New(e.Details) } return e.Err } diff --git a/acme/order.go b/acme/order.go index 01d3bc20..e0ac822b 100644 --- a/acme/order.go +++ b/acme/order.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/x509util" ) @@ -41,7 +40,7 @@ type Order struct { func (o *Order) ToLog() (interface{}, error) { b, err := json.Marshal(o) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging")) + return nil, ErrorInternalServerWrap(err, "error marshaling order for logging") } return string(b), nil } @@ -52,7 +51,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { now := time.Now().UTC() expiry, err := time.Parse(time.RFC3339, o.Expires) if err != nil { - return ServerInternalErr(errors.Wrap(err, "order.UpdateStatus - error converting expiry string to time")) + return ErrorInternalServerWrap(err, "order.UpdateStatus - error converting expiry string to time") } switch o.Status { @@ -64,7 +63,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { // Check expiry if now.After(expiry) { o.Status = StatusInvalid - o.Error = MalformedErr(errors.New("order has expired")) + o.Error = NewError(ErrorMalformedType, "order has expired") break } return nil @@ -72,7 +71,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { // Check expiry if now.After(expiry) { o.Status = StatusInvalid - o.Error = MalformedErr(errors.New("order has expired")) + o.Error = NewError(ErrorMalformedType, "order has expired") break } @@ -105,10 +104,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { o.Status = StatusReady default: - return ServerInternalErr(errors.New("unexpected authz status")) + return NewError(ErrorServerInternalType, "unexpected authz status") } default: - return ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) + return NewError(ErrorServerInternalType, "unrecognized order status: %s", o.Status) } return db.UpdateOrder(ctx, o) } @@ -122,15 +121,15 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques switch o.Status { case StatusInvalid: - return OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) + return NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID) case StatusValid: return nil case StatusPending: - return OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) + return NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID) case StatusReady: break default: - return ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) + return NewError(ErrorServerInternalType, "unexpected status %s for order %s", o.Status, o.ID) } // RFC8555: The CSR MUST indicate the exact same set of requested @@ -154,13 +153,15 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques // absence of other SANs as they will only be set if the templates allows // them. if len(csr.DNSNames) != len(orderNames) { - return BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) } sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) for i := range csr.DNSNames { if csr.DNSNames[i] != orderNames[i] { - return BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) } sans[i] = x509util.SubjectAlternativeName{ Type: x509util.DNSType, @@ -172,7 +173,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) + return ErrorInternalServerWrap(err, "error retrieving authorization options from ACME provisioner") } // Template data @@ -182,17 +183,17 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) + return ErrorInternalServerWrap(err, "error creating template options from ACME provisioner") } signOps = append(signOps, templateOptions) nbf, err := time.Parse(time.RFC3339, o.NotBefore) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error parsing order NotBefore")) + return ErrorInternalServerWrap(err, "error parsing order NotBefore") } naf, err := time.Parse(time.RFC3339, o.NotAfter) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error parsing order NotAfter")) + return ErrorInternalServerWrap(err, "error parsing order NotAfter") } // Sign a new certificate. @@ -201,7 +202,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques NotAfter: provisioner.NewTimeDuration(naf), }, signOps...) if err != nil { - return ServerInternalErr(errors.Wrapf(err, "error signing certificate for order %s", o.ID)) + return ErrorInternalServerWrap(err, "error signing certificate for order %s", o.ID) } cert := &Certificate{ From 1135ae04fc97fa21a3c47be211df21aff8c846fc Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 3 Mar 2021 15:16:25 -0800 Subject: [PATCH 09/47] [acme db interface] wip --- acme/api/account.go | 22 ++--- acme/api/order.go | 4 +- acme/authority.go | 169 +++++++++++++++++++++++++++++++------ acme/authorization.go | 22 ++--- acme/challenge.go | 14 +-- acme/db/nosql/authz.go | 44 ++-------- acme/db/nosql/challenge.go | 18 ---- acme/db/nosql/order.go | 44 +--------- acme/directory.go | 2 +- acme/errors.go | 63 +++++++------- acme/order.go | 68 ++++++--------- api/errors.go | 3 +- 12 files changed, 251 insertions(+), 222 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index ec2854cc..5e208a5f 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -21,7 +21,7 @@ type NewAccountRequest struct { func validateContacts(cs []string) error { for _, c := range cs { if len(c) == 0 { - return acme.MalformedErr(errors.New("contact cannot be empty string")) + return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string") } } return nil @@ -30,7 +30,7 @@ func validateContacts(cs []string) error { // Validate validates a new-account request body. func (n *NewAccountRequest) Validate() error { if n.OnlyReturnExisting && len(n.Contact) > 0 { - return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone")) + return acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone") } return validateContacts(n.Contact) } @@ -51,8 +51,8 @@ func (u *UpdateAccountRequest) IsDeactivateRequest() bool { func (u *UpdateAccountRequest) Validate() error { switch { case len(u.Status) > 0 && len(u.Contact) > 0: - return acme.MalformedErr(errors.New("incompatible input; contact and " + - "status updates are mutually exclusive")) + return acme.NewError(acme.ErrorMalformedType, "incompatible input; contact and "+ + "status updates are mutually exclusive") case len(u.Contact) > 0: if err := validateContacts(u.Contact); err != nil { return err @@ -60,8 +60,8 @@ func (u *UpdateAccountRequest) Validate() error { return nil case len(u.Status) > 0: if u.Status != string(acme.StatusDeactivated) { - return acme.MalformedErr(errors.Errorf("cannot update account "+ - "status to %s, only deactivated", u.Status)) + return acme.NewError(acme.ErrorMalformedType, "cannot update account "+ + "status to %s, only deactivated", u.Status) } return nil default: @@ -80,8 +80,8 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, - "failed to unmarshal new-account request payload"))) + api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, + "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { @@ -101,7 +101,8 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { // Account does not exist // if nar.OnlyReturnExisting { - api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + "account does not exist")) return } jwk, err := acme.JwkFromContext(r.Context()) @@ -146,7 +147,8 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload"))) + api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, + "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { diff --git a/acme/api/order.go b/acme/api/order.go index 5c62cb52..1fead85c 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -23,11 +23,11 @@ type NewOrderRequest struct { // Validate validates a new-order request body. func (n *NewOrderRequest) Validate() error { if len(n.Identifiers) == 0 { - return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")) + return acme.NewError(ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { if id.Type != "dns" { - return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type)) + return acme.NewError(ErrorMalformedType, "identifier type unsupported: %s", id.Type) } } return nil diff --git a/acme/authority.go b/acme/authority.go index 098c48d4..92e1c8f7 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -8,10 +8,12 @@ import ( "net" "net/http" "net/url" + "strings" "time" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" + "go.step.sm/crypto/randutil" ) // Interface is the acme authority interface. @@ -124,7 +126,7 @@ func (a *Authority) UseNonce(ctx context.Context, nonce string) error { // NewAccount creates, stores, and returns a new ACME account. func (a *Authority) NewAccount(ctx context.Context, acc *Account) error { if err := a.db.CreateAccount(ctx, acc); err != nil { - return ErrorWrap(ErrorServerInternalType, err, "error creating account") + return ErrorISEWrap(err, "error creating account") } return nil } @@ -136,7 +138,7 @@ func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, acc.Status = auo.Status */ if err := a.db.UpdateAccount(ctx, acc); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error updating account") + return nil, ErrorISEWrap(err, "error updating account") } return acc, nil } @@ -145,7 +147,7 @@ func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { acc, err := a.db.GetAccount(ctx, id) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving account") + return nil, ErrorISEWrap(err, "error retrieving account") } return acc, nil } @@ -168,7 +170,7 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order } o, err := a.db.GetOrder(ctx, orderID) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving order") + return nil, ErrorISEWrap(err, "error retrieving order") } if accID != o.AccountID { log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) @@ -179,7 +181,7 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") } if err = o.UpdateStatus(ctx, a.db); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error updating order") + return nil, ErrorISEWrap(err, "error updating order") } return o, nil } @@ -205,19 +207,54 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string */ // NewOrder generates, stores, and returns a new ACME order. -func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err +func (a *Authority) NewOrder(ctx context.Context, o *Order) error { + if len(o.AccountID) == 0 { + return NewErrorISE("account-id cannot be empty") + } + if len(o.ProvisionerID) == 0 { + return NewErrorISE("provisioner-id cannot be empty") + } + if len(o.Identifiers) == 0 { + return NewErrorISE("identifiers cannot be empty") + } + if o.DefaultDuration == 0 { + return NewErrorISE("default-duration cannot be empty") } - o.DefaultDuration = prov.DefaultTLSCertDuration() - o.Backdate = a.backdate.Duration - o.ProvisionerID = prov.GetID() - if err = a.db.CreateOrder(ctx, o); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error creating order") + o.AuthorizationIDs = make([]string, len(o.Identifiers)) + for i, identifier := range o.Identifiers { + az := &Authorization{ + AccountID: o.AccountID, + Identifier: identifier, + } + if err := a.NewAuthorization(ctx, az); err != nil { + return err + } + o.AuthorizationIDs[i] = az.ID } - return o, nil + + now := clock.Now() + if o.NotBefore.IsZero() { + o.NotBefore = now + } + if o.NotAfter.IsZero() { + o.NotAfter = o.NotBefore.Add(o.DefaultDuration) + } + + if err := a.db.CreateOrder(ctx, o); err != nil { + return ErrorISEWrap(err, "error creating order") + } + return nil + /* + o.DefaultDuration = prov.DefaultTLSCertDuration() + o.Backdate = a.backdate.Duration + o.ProvisionerID = prov.GetID() + + if err = a.db.CreateOrder(ctx, o); err != nil { + return nil, ErrorWrap(ErrorServerInternalType, err, "error creating order") + } + return o, nil + */ } // FinalizeOrder attempts to finalize an order and generate a new certificate. @@ -228,7 +265,7 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs } o, err := a.db.GetOrder(ctx, orderID) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving order") + return nil, ErrorISEWrap(err, "error retrieving order") } if accID != o.AccountID { log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) @@ -239,33 +276,113 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") } if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error finalizing order") + return nil, ErrorISEWrap(err, "error finalizing order") } return o, nil } -// GetAuthz retrieves and attempts to update the status on an ACME authz +// NewAuthorization generates and stores an ACME Authorization type along with +// any associated resources. +func (a *Authority) NewAuthorization(ctx context.Context, az *Authorization) error { + if len(az.AccountID) == 0 { + return NewErrorISE("account-id cannot be empty") + } + if len(az.Identifier.Value) == 0 { + return NewErrorISE("identifier cannot be empty") + } + + if strings.HasPrefix(az.Identifier.Value, "*.") { + az.Wildcard = true + az.Identifier = Identifier{ + Value: strings.TrimPrefix(az.Identifier.Value, "*."), + Type: az.Identifier.Type, + } + } + + var ( + err error + chTypes = []string{"dns-01"} + ) + // HTTP and TLS challenges can only be used for identifiers without wildcards. + if !az.Wildcard { + chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) + } + + az.Token, err = randutil.Alphanumeric(32) + if err != nil { + return ErrorISEWrap(err, "error generating random alphanumeric ID") + } + + az.Challenges = make([]*Challenge, len(chTypes)) + for i, typ := range chTypes { + ch := &Challenge{ + AccountID: az.AccountID, + AuthzID: az.ID, + Value: az.Identifier.Value, + Type: typ, + Token: az.Token, + } + if err := a.NewChallenge(ctx, ch); err != nil { + return err + } + az.Challenges[i] = ch + } + if err = a.db.CreateAuthorization(ctx, az); err != nil { + return ErrorISEWrap(err, "error creating authorization") + } + return nil +} + +// GetAuthorization retrieves and attempts to update the status on an ACME authz // before returning. -func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authorization, error) { +func (a *Authority) GetAuthorization(ctx context.Context, accID, authzID string) (*Authorization, error) { az, err := a.db.GetAuthorization(ctx, authzID) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving authorization") + return nil, ErrorISEWrap(err, "error retrieving authorization") } if accID != az.AccountID { log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) return nil, NewError(ErrorUnauthorizedType, "account does not own order") } if err = az.UpdateStatus(ctx, a.db); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error updating authorization status") + return nil, ErrorISEWrap(err, "error updating authorization status") } return az, nil } -// ValidateChallenge attempts to validate the challenge. -func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { +// NewChallenge generates and stores an ACME challenge and associated resources. +func (a *Authority) NewChallenge(ctx context.Context, ch *Challenge) error { + if len(ch.AccountID) == 0 { + return NewErrorISE("account-id cannot be empty") + } + if len(ch.AuthzID) == 0 { + return NewErrorISE("authz-id cannot be empty") + } + if len(ch.Token) == 0 { + return NewErrorISE("token cannot be empty") + } + if len(ch.Value) == 0 { + return NewErrorISE("value cannot be empty") + } + + switch ch.Type { + case "dns-01", "http-01", "tls-alpn-01": + break + default: + return NewErrorISE("unexpected error type '%s'", ch.Type) + } + + if err := a.db.CreateChallenge(ctx, ch); err != nil { + return ErrorISEWrap(err, "error creating challenge") + } + return nil +} + +// GetValidateChallenge attempts to validate the challenge. +func (a *Authority) GetValidateChallenge(ctx context.Context, accID, chID, azID string, jwk *jose.JSONWebKey) (*Challenge, error) { ch, err := a.db.GetChallenge(ctx, chID, "todo") if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving challenge") + return nil, ErrorISEWrap(err, "error retrieving challenge") } if accID != ch.AccountID { log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, ch.AccountID) @@ -284,7 +401,7 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j return tls.DialWithDialer(dialer, network, addr, config) }, }); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error validating challenge") + return nil, ErrorISEWrap(err, "error validating challenge") } return ch, nil } @@ -293,7 +410,7 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { cert, err := a.db.GetCertificate(ctx, certID) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error retrieving certificate") + return nil, ErrorISEWrap(err, "error retrieving certificate") } if cert.AccountID != accID { log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) diff --git a/acme/authorization.go b/acme/authorization.go index ef230286..e4bc669d 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -8,20 +8,22 @@ import ( // Authorization representst an ACME Authorization. type Authorization struct { - Identifier *Identifier `json:"identifier"` - Status Status `json:"status"` - Expires string `json:"expires"` - Challenges []*Challenge `json:"challenges"` - Wildcard bool `json:"wildcard"` - ID string `json:"-"` - AccountID string `json:"-"` + Identifier Identifier `json:"identifier"` + Status Status `json:"status"` + Expires string `json:"expires"` + Challenges []*Challenge `json:"challenges"` + ChallengeIDs string `json::"-"` + Wildcard bool `json:"wildcard"` + ID string `json:"-"` + AccountID string `json:"-"` + Token string `json:"-"` } // ToLog enables response logging. func (az *Authorization) ToLog() (interface{}, error) { b, err := json.Marshal(az) if err != nil { - return nil, ErrorInternalServerWrap(err, "error marshaling authz for logging") + return nil, ErrorISEWrap(err, "error marshaling authz for logging") } return string(b), nil } @@ -32,7 +34,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { now := time.Now().UTC() expiry, err := time.Parse(time.RFC3339, az.Expires) if err != nil { - return ErrorInternalServerWrap(err, "error converting expiry string to time") + return ErrorISEWrap(err, "error converting expiry string to time") } switch az.Status { @@ -64,7 +66,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { } if err = db.UpdateAuthorization(ctx, az); err != nil { - return ErrorInternalServerWrap(err, "error updating authorization") + return ErrorISEWrap(err, "error updating authorization") } return nil } diff --git a/acme/challenge.go b/acme/challenge.go index 05987427..ca2e5562 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -38,7 +38,7 @@ type Challenge struct { func (ch *Challenge) ToLog() (interface{}, error) { b, err := json.Marshal(ch) if err != nil { - return nil, ErrorInternalServerWrap(err, "error marshaling challenge for logging") + return nil, ErrorISEWrap(err, "error marshaling challenge for logging") } return string(b), nil } @@ -80,7 +80,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb body, err := ioutil.ReadAll(resp.Body) if err != nil { - return ErrorInternalServerWrap(err, "error reading "+ + return ErrorISEWrap(err, "error reading "+ "response body for url %s", url) } keyAuth := strings.Trim(string(body), "\r\n") @@ -100,7 +100,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorInternalServerWrap(err, "error updating challenge") + return ErrorISEWrap(err, "error updating challenge") } return nil } @@ -178,7 +178,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorInternalServerWrap(err, "tlsalpn01ValidateChallenge - error updating challenge") + return ErrorISEWrap(err, "tlsalpn01ValidateChallenge - error updating challenge") } return nil } @@ -234,7 +234,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK ch.Validated = clock.Now().UTC().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorInternalServerWrap(err, "error updating challenge") + return ErrorISEWrap(err, "error updating challenge") } return nil } @@ -244,7 +244,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return "", ErrorInternalServerWrap(err, "error generating JWK thumbprint") + return "", ErrorISEWrap(err, "error generating JWK thumbprint") } encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) return fmt.Sprintf("%s.%s", token, encPrint), nil @@ -254,7 +254,7 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { ch.Error = err if err := db.UpdateChallenge(ctx, ch); err != nil { - return ErrorInternalServerWrap(err, "failure saving error to acme challenge") + return ErrorISEWrap(err, "failure saving error to acme challenge") } return nil } diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index bc9f75bc..818f5c2d 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -3,7 +3,6 @@ package nosql import ( "context" "encoding/json" - "strings" "time" "github.com/pkg/errors" @@ -75,18 +74,17 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat // CreateAuthorization creates an entry in the database for the Authorization. // Implements the acme.DB.CreateAuthorization interface. func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error { - if len(az.AccountID) == 0 { - return errors.New("account-id cannot be empty") - } - if az.Identifier == nil { - return errors.New("identifier cannot be nil") - } var err error az.ID, err = randID() if err != nil { return err } + chIDs := make([]string, len(az.Challenges)) + for i, ch := range az.Challenges { + chIDs[i] = ch.ID + } + now := clock.Now() dbaz := &dbAuthz{ ID: az.ID, @@ -95,38 +93,10 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e Created: now, Expires: now.Add(defaultExpiryDuration), Identifier: az.Identifier, + Challenges: chIDs, + Wildcard: az.Wildcard, } - if strings.HasPrefix(az.Identifier.Value, "*.") { - dbaz.Wildcard = true - dbaz.Identifier = &acme.Identifier{ - Value: strings.TrimPrefix(az.Identifier.Value, "*."), - Type: az.Identifier.Type, - } - } - - chIDs := []string{} - chTypes := []string{"dns-01"} - // HTTP and TLS challenges can only be used for identifiers without wildcards. - if !dbaz.Wildcard { - chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) - } - - for _, typ := range chTypes { - ch := &acme.Challenge{ - AccountID: az.AccountID, - AuthzID: az.ID, - Value: az.Identifier.Value, - Type: typ, - } - if err = db.CreateChallenge(ctx, ch); err != nil { - return errors.Wrapf(err, "error creating challenge") - } - - chIDs = append(chIDs, ch.ID) - } - dbaz.Challenges = chIDs - return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable) } diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index 62513778..378b1f7b 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -47,29 +47,11 @@ func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, erro // CreateChallenge creates a new ACME challenge data structure in the database. // Implements acme.DB.CreateChallenge interface. func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error { - if len(ch.AuthzID) == 0 { - return errors.New("AuthzID cannot be empty") - } - if len(ch.AccountID) == 0 { - return errors.New("AccountID cannot be empty") - } - if len(ch.Value) == 0 { - return errors.New("AccountID cannot be empty") - } - // TODO: verify that challenge type is set and is one of expected types. - if len(ch.Type) == 0 { - return errors.New("Type cannot be empty") - } - var err error ch.ID, err = randID() if err != nil { return errors.Wrap(err, "error generating random id for ACME challenge") } - ch.Token, err = randID() - if err != nil { - return errors.Wrap(err, "error generating token for ACME challenge") - } dbch := &dbChallenge{ ID: ch.ID, diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 528619d4..d2146e22 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -74,48 +74,12 @@ func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { // CreateOrder creates ACME Order resources and saves them to the DB. func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { - if len(o.AccountID) == 0 { - return ServerInternalErr(errors.New("account-id cannot be empty")) - } - if len(o.ProvisionerID) == 0 { - return ServerInternalErr(errors.New("provisioner-id cannot be empty")) - } - if len(o.Identifiers) == 0 { - return ServerInternalErr(errors.New("identifiers cannot be empty")) - } - if o.DefaultDuration == 0 { - return ServerInternalErr(errors.New("default-duration cannot be empty")) - } - o.ID, err = randID() if err != nil { return nil, err } - azIDs := make([]string, len(ops.Identifiers)) - for i, identifier := range ops.Identifiers { - az, err = db.CreateAuthorzation(&types.Authorization{ - AccountID: o.AccountID, - Identifier: o.Identifier, - }) - if err != nil { - return err - } - azIDs[i] = az.ID - } - now := clock.Now() - var backdate time.Duration - nbf := o.NotBefore - if nbf.IsZero() { - nbf = now - backdate = -1 * o.Backdate - } - naf := o.NotAfter - if naf.IsZero() { - naf = nbf.Add(o.DefaultDuration) - } - dbo := &dbOrder{ ID: o.ID, AccountID: o.AccountID, @@ -123,10 +87,10 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { Created: now, Status: StatusPending, Expires: now.Add(defaultOrderExpiry), - Identifiers: ops.Identifiers, - NotBefore: nbf.Add(backdate), - NotAfter: naf, - Authorizations: azIDs, + Identifiers: o.Identifiers, + NotBefore: o.NotBefore, + NotAfter: o.NotBefore, + Authorizations: o.AuthorizationIDs, } if err := db.save(ctx, o.ID, dbo, nil, orderTable); err != nil { return nil, err diff --git a/acme/directory.go b/acme/directory.go index 1b5b8c4b..8520d0e9 100644 --- a/acme/directory.go +++ b/acme/directory.go @@ -21,7 +21,7 @@ type Directory struct { func (d *Directory) ToLog() (interface{}, error) { b, err := json.Marshal(d) if err != nil { - return nil, ErrorInternalServerWrap(err, "error marshaling directory for logging") + return nil, ErrorISEWrap(err, "error marshaling directory for logging") } return string(b), nil } diff --git a/acme/errors.go b/acme/errors.go index aabc7302..8fe2559d 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -1,4 +1,3 @@ -// Error represents an ACME package acme import ( @@ -11,55 +10,55 @@ import ( type ProblemType int const ( - // The request specified an account that does not exist + // ErrorAccountDoesNotExistType request specified an account that does not exist ErrorAccountDoesNotExistType ProblemType = iota - // The request specified a certificate to be revoked that has already been revoked + // ErrorAlreadyRevokedType request specified a certificate to be revoked that has already been revoked ErrorAlreadyRevokedType - // The CSR is unacceptable (e.g., due to a short key) + // ErrorBadCSRType CSR is unacceptable (e.g., due to a short key) ErrorBadCSRType - // The client sent an unacceptable anti-replay nonce + // ErrorBadNonceType client sent an unacceptable anti-replay nonce ErrorBadNonceType - // The JWS was signed by a public key the server does not support + // ErrorBadPublicKeyType JWS was signed by a public key the server does not support ErrorBadPublicKeyType - // The revocation reason provided is not allowed by the server + // ErrorBadRevocationReasonType revocation reason provided is not allowed by the server ErrorBadRevocationReasonType - // The JWS was signed with an algorithm the server does not support + // ErrorBadSignatureAlgorithmType JWS was signed with an algorithm the server does not support ErrorBadSignatureAlgorithmType - // Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate + // ErrorCaaType Authority Authorization (CAA) records forbid the CA from issuing a certificate ErrorCaaType - // Specific error conditions are indicated in the “subproblems” array. + // ErrorCompoundType error conditions are indicated in the “subproblems” array. ErrorCompoundType - // The server could not connect to validation target + // ErrorConnectionType server could not connect to validation target ErrorConnectionType - // There was a problem with a DNS query during identifier validation + // ErrorDNSType was a problem with a DNS query during identifier validation ErrorDNSType - // The request must include a value for the “externalAccountBinding” field + // ErrorExternalAccountRequiredType request must include a value for the “externalAccountBinding” field ErrorExternalAccountRequiredType - // Response received didn’t match the challenge’s requirements + // ErrorIncorrectResponseType received didn’t match the challenge’s requirements ErrorIncorrectResponseType - // A contact URL for an account was invalid + // ErrorInvalidContactType URL for an account was invalid ErrorInvalidContactType - // The request message was malformed + // ErrorMalformedType request message was malformed ErrorMalformedType - // The request attempted to finalize an order that is not ready to be finalized + // ErrorOrderNotReadyType request attempted to finalize an order that is not ready to be finalized ErrorOrderNotReadyType - // The request exceeds a rate limit + // ErrorRateLimitedType request exceeds a rate limit ErrorRateLimitedType - // The server will not issue certificates for the identifier + // ErrorRejectedIdentifierType server will not issue certificates for the identifier ErrorRejectedIdentifierType - // The server experienced an internal error + // ErrorServerInternalType server experienced an internal error ErrorServerInternalType - // The server received a TLS error during validation + // ErrorTLSType server received a TLS error during validation ErrorTLSType - // The client lacks sufficient authorization + // ErrorUnauthorizedType client lacks sufficient authorization ErrorUnauthorizedType - // A contact URL for an account used an unsupported protocol scheme + // ErrorUnsupportedContactType URL for an account used an unsupported protocol scheme ErrorUnsupportedContactType - // An identifier is of an unsupported type + // ErrorUnsupportedIdentifierType identifier is of an unsupported type ErrorUnsupportedIdentifierType - // Visit the “instance” URL and take actions specified there + // ErrorUserActionRequiredType the “instance” URL and take actions specified there ErrorUserActionRequiredType - // The operation is not implemented + // ErrorNotImplementedType operation is not implemented ErrorNotImplementedType ) @@ -116,7 +115,7 @@ func (ap ProblemType) String() string { case ErrorNotImplementedType: return "notImplemented" default: - return fmt.Sprintf("unsupported type ACME error type %v", ap) + return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap)) } } @@ -270,6 +269,7 @@ type Error struct { Status int `json:"-"` } +// NewError creates a new Error type. func NewError(pt ProblemType, msg string, args ...interface{}) *Error { meta, ok := errorMap[pt] if !ok { @@ -290,6 +290,11 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error { } } +// NewErrorISE creates a new ErrorServerInternalType Error. +func NewErrorISE(msg string, args ...interface{}) *Error { + return NewError(ErrorServerInternalType, msg, args...) +} + // ErrorWrap attempts to wrap the internal error. func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Error { switch e := err.(type) { @@ -307,8 +312,8 @@ func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Err } } -// ErrorInternalServerWrap shortcut to wrap an internal server error type. -func ErrorInternalServerWrap(err error, msg string, args ...interface{}) *Error { +// ErrorISEWrap shortcut to wrap an internal server error type. +func ErrorISEWrap(err error, msg string, args ...interface{}) *Error { return ErrorWrap(ErrorServerInternalType, err, msg, args...) } diff --git a/acme/order.go b/acme/order.go index e0ac822b..bf3297f9 100644 --- a/acme/order.go +++ b/acme/order.go @@ -20,27 +20,28 @@ type Identifier struct { // Order contains order metadata for the ACME protocol order type. type Order struct { - Status Status `json:"status"` - Expires string `json:"expires,omitempty"` - Identifiers []Identifier `json:"identifiers"` - NotBefore string `json:"notBefore,omitempty"` - NotAfter string `json:"notAfter,omitempty"` - Error interface{} `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - FinalizeURL string `json:"finalize"` - Certificate string `json:"certificate,omitempty"` - ID string `json:"-"` - AccountID string `json:"-"` - ProvisionerID string `json:"-"` - DefaultDuration time.Duration `json:"-"` - Backdate time.Duration `json:"-"` + Status Status `json:"status"` + Expires time.Time `json:"expires,omitempty"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + Error interface{} `json:"error,omitempty"` + AuthorizationURLs []string `json:"authorizations"` + AuthorizationIDs []string `json:"-"` + FinalizeURL string `json:"finalize"` + Certificate string `json:"certificate,omitempty"` + ID string `json:"-"` + AccountID string `json:"-"` + ProvisionerID string `json:"-"` + DefaultDuration time.Duration `json:"-"` + Backdate time.Duration `json:"-"` } // ToLog enables response logging. func (o *Order) ToLog() (interface{}, error) { b, err := json.Marshal(o) if err != nil { - return nil, ErrorInternalServerWrap(err, "error marshaling order for logging") + return nil, ErrorISEWrap(err, "error marshaling order for logging") } return string(b), nil } @@ -49,10 +50,6 @@ func (o *Order) ToLog() (interface{}, error) { // Changes to the order are saved using the database interface. func (o *Order) UpdateStatus(ctx context.Context, db DB) error { now := time.Now().UTC() - expiry, err := time.Parse(time.RFC3339, o.Expires) - if err != nil { - return ErrorInternalServerWrap(err, "order.UpdateStatus - error converting expiry string to time") - } switch o.Status { case StatusInvalid: @@ -61,7 +58,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusReady: // Check expiry - if now.After(expiry) { + if now.After(o.Expires) { o.Status = StatusInvalid o.Error = NewError(ErrorMalformedType, "order has expired") break @@ -69,7 +66,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusPending: // Check expiry - if now.After(expiry) { + if now.After(o.Expires) { o.Status = StatusInvalid o.Error = NewError(ErrorMalformedType, "order has expired") break @@ -80,7 +77,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { StatusInvalid: 0, StatusPending: 0, } - for _, azID := range o.Authorizations { + for _, azID := range o.AuthorizationIDs { az, err := db.GetAuthorization(ctx, azID) if err != nil { return err @@ -100,14 +97,14 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { case count[StatusPending] > 0: return nil - case count[StatusValid] == len(o.Authorizations): + case count[StatusValid] == len(o.AuthorizationIDs): o.Status = StatusReady default: - return NewError(ErrorServerInternalType, "unexpected authz status") + return NewErrorISE("unexpected authz status") } default: - return NewError(ErrorServerInternalType, "unrecognized order status: %s", o.Status) + return NewErrorISE("unrecognized order status: %s", o.Status) } return db.UpdateOrder(ctx, o) } @@ -129,7 +126,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques case StatusReady: break default: - return NewError(ErrorServerInternalType, "unexpected status %s for order %s", o.Status, o.ID) + return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID) } // RFC8555: The CSR MUST indicate the exact same set of requested @@ -173,7 +170,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return ErrorInternalServerWrap(err, "error retrieving authorization options from ACME provisioner") + return ErrorISEWrap(err, "error retrieving authorization options from ACME provisioner") } // Template data @@ -183,26 +180,17 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return ErrorInternalServerWrap(err, "error creating template options from ACME provisioner") + return ErrorISEWrap(err, "error creating template options from ACME provisioner") } signOps = append(signOps, templateOptions) - nbf, err := time.Parse(time.RFC3339, o.NotBefore) - if err != nil { - return ErrorInternalServerWrap(err, "error parsing order NotBefore") - } - naf, err := time.Parse(time.RFC3339, o.NotAfter) - if err != nil { - return ErrorInternalServerWrap(err, "error parsing order NotAfter") - } - // Sign a new certificate. certChain, err := auth.Sign(csr, provisioner.SignOptions{ - NotBefore: provisioner.NewTimeDuration(nbf), - NotAfter: provisioner.NewTimeDuration(naf), + NotBefore: provisioner.NewTimeDuration(o.NotBefore), + NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) if err != nil { - return ErrorInternalServerWrap(err, "error signing certificate for order %s", o.ID) + return ErrorISEWrap(err, "error signing certificate for order %s", o.ID) } cert := &Certificate{ diff --git a/api/errors.go b/api/errors.go index 93057ed2..460192fc 100644 --- a/api/errors.go +++ b/api/errors.go @@ -14,10 +14,9 @@ import ( // WriteError writes to w a JSON representation of the given error. func WriteError(w http.ResponseWriter, err error) { - switch k := err.(type) { + switch err.(type) { case *acme.Error: w.Header().Set("Content-Type", "application/problem+json") - err = k.ToACME() default: w.Header().Set("Content-Type", "application/json") } From 491c188a5e5f8406ba3e2ce799c85a69dc2624e0 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 4 Mar 2021 13:18:41 -0800 Subject: [PATCH 10/47] [acme db interface] wip --- acme/authorization.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/acme/authorization.go b/acme/authorization.go index e4bc669d..7f15f4c6 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -8,15 +8,14 @@ import ( // Authorization representst an ACME Authorization. type Authorization struct { - Identifier Identifier `json:"identifier"` - Status Status `json:"status"` - Expires string `json:"expires"` - Challenges []*Challenge `json:"challenges"` - ChallengeIDs string `json::"-"` - Wildcard bool `json:"wildcard"` - ID string `json:"-"` - AccountID string `json:"-"` - Token string `json:"-"` + Identifier Identifier `json:"identifier"` + Status Status `json:"status"` + Expires time.Time `json:"expires"` + Challenges []*Challenge `json:"challenges"` + Wildcard bool `json:"wildcard"` + ID string `json:"-"` + AccountID string `json:"-"` + Token string `json:"-"` } // ToLog enables response logging. From 80a6640103132db93fa8d45edf4c3f11a2f2533b Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 4 Mar 2021 23:10:46 -0800 Subject: [PATCH 11/47] [acme db interface] wip --- acme/account.go | 4 +- acme/api/account.go | 92 ++++---- acme/api/handler.go | 220 +++++++++++++------ acme/api/linker.go | 164 +++++++++++++++ acme/api/middleware.go | 195 ++++++++++++----- acme/api/order.go | 170 ++++++++++++--- acme/authority.go | 420 ------------------------------------- acme/authorization.go | 16 +- acme/certificate.go | 3 +- acme/challenge.go | 45 ++-- acme/common.go | 104 ++------- acme/db/nosql/account.go | 1 - acme/db/nosql/authz.go | 20 +- acme/db/nosql/challenge.go | 3 +- acme/db/nosql/order.go | 73 +++---- acme/directory.go | 148 ------------- acme/errors.go | 34 +-- acme/nonce.go | 6 + acme/order.go | 17 +- ca/acmeClient.go | 12 +- ca/ca.go | 7 +- 21 files changed, 783 insertions(+), 971 deletions(-) create mode 100644 acme/api/linker.go delete mode 100644 acme/authority.go delete mode 100644 acme/directory.go diff --git a/acme/account.go b/acme/account.go index 80cc66ef..354ebdc7 100644 --- a/acme/account.go +++ b/acme/account.go @@ -22,7 +22,7 @@ type Account struct { func (a *Account) ToLog() (interface{}, error) { b, err := json.Marshal(a) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error marshaling account for logging") + return nil, WrapErrorISE(err, "error marshaling account for logging") } return string(b), nil } @@ -46,7 +46,7 @@ func (a *Account) IsValid() bool { func KeyToID(jwk *jose.JSONWebKey) (string, error) { kid, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return "", ErrorWrap(ErrorServerInternalType, err, "error generating jwk thumbprint") + return "", WrapErrorISE(err, "error generating jwk thumbprint") } return base64.RawURLEncoding.EncodeToString(kid), nil } diff --git a/acme/api/account.go b/acme/api/account.go index 5e208a5f..16cc1f79 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -4,8 +4,6 @@ import ( "encoding/json" "net/http" - "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/logging" @@ -37,14 +35,8 @@ func (n *NewAccountRequest) Validate() error { // UpdateAccountRequest represents an update-account request. type UpdateAccountRequest struct { - Contact []string `json:"contact"` - Status string `json:"status"` -} - -// IsDeactivateRequest returns true if the update request is a deactivation -// request, false otherwise. -func (u *UpdateAccountRequest) IsDeactivateRequest() bool { - return u.Status == string(acme.StatusDeactivated) + Contact []string `json:"contact"` + Status acme.Status `json:"status"` } // Validate validates a update-account request body. @@ -59,7 +51,7 @@ func (u *UpdateAccountRequest) Validate() error { } return nil case len(u.Status) > 0: - if u.Status != string(acme.StatusDeactivated) { + if u.Status != acme.StatusDeactivated { return acme.NewError(acme.ErrorMalformedType, "cannot update account "+ "status to %s, only deactivated", u.Status) } @@ -80,7 +72,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } @@ -90,7 +82,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } httpStatus := http.StatusCreated - acc, err := acme.AccountFromContext(r.Context()) + acc, err := accountFromContext(r.Context()) if err != nil { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { @@ -105,18 +97,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { "account does not exist")) return } - jwk, err := acme.JwkFromContext(r.Context()) + jwk, err := jwkFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - if acc, err = h.Auth.NewAccount(r.Context(), &acme.Account{ + acc := &acme.Account{ Key: jwk, Contact: nar.Contact, Status: acme.StatusValid, - }); err != nil { - api.WriteError(w, err) + } + if err := h.db.CreateAccount(r.Context(), acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) return } } else { @@ -124,14 +117,16 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, - true, acc.GetID())) + h.linker.LinkAccount(ctx, acc) + + w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, + true, acc.ID)) api.JSONStatus(w, acc, httpStatus) } // GetUpdateAccount is the api for updating an ACME account. func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + acc, err := accountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -147,7 +142,7 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } @@ -159,18 +154,18 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { // If neither the status nor the contacts are being updated then ignore // the updates and return 200. This conforms with the behavior detailed // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). - if uar.IsDeactivateRequest() { - acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID()) - } else if len(uar.Contact) > 0 { - acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact) - } - if err != nil { - api.WriteError(w, err) + acc.Status = uar.Status + acc.Contact = uar.Contact + if err = h.db.UpdateAccount(r.Context(), acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) return } } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, - true, acc.GetID())) + + h.linker.LinkAccount(ctx, acc) + + w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, + true, acc.ID)) api.JSON(w, acc) } @@ -185,21 +180,24 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { // GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) - if err != nil { - api.WriteError(w, err) - return - } - accID := chi.URLParam(r, "accID") - if acc.ID != accID { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) - return - } - orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) - if err != nil { - api.WriteError(w, err) - return - } - api.JSON(w, orders) - logOrdersByAccount(w, orders) + /* + acc, err := acme.AccountFromContext(r.Context()) + if err != nil { + api.WriteError(w, err) + return + } + accID := chi.URLParam(r, "accID") + if acc.ID != accID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param")) + return + } + orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + api.JSON(w, orders) + logOrdersByAccount(w, orders) + */ + return } diff --git a/acme/api/handler.go b/acme/api/handler.go index 921e614e..997456a7 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,56 +1,82 @@ package api import ( - "context" - "crypto/x509" - "encoding/pem" + "crypto/tls" + "encoding/json" "fmt" + "net" "net/http" + "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" ) func link(url, typ string) string { return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) } +// Clock that returns time in UTC rounded to seconds. +type Clock int + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Round(time.Second) +} + +var clock = new(Clock) + type payloadInfo struct { value []byte isPostAsGet bool isEmptyJSON bool } -// payloadFromContext searches the context for a payload. Returns the payload -// or an error. -func payloadFromContext(ctx context.Context) (*payloadInfo, error) { - val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo) - if !ok || val == nil { - return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context")) - } - return val, nil +// Handler is the ACME API request handler. +type Handler struct { + db acme.DB + backdate provisioner.Duration + ca acme.CertificateAuthority + linker *Linker } -// New returns a new ACME API router. -func New(acmeAuth acme.Interface) api.RouterHandler { - return &Handler{acmeAuth} +// HandlerOptions required to create a new ACME API request handler. +type HandlerOptions struct { + Backdate provisioner.Duration + // DB storage backend that impements the acme.DB interface. + DB acme.DB + // DNS the host used to generate accurate ACME links. By default the authority + // will use the Host from the request, so this value will only be used if + // request.Host is empty. + DNS string + // Prefix is a URL path prefix under which the ACME api is served. This + // prefix is required to generate accurate ACME links. + // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- + // "acme" is the prefix from which the ACME api is accessed. + Prefix string + CA acme.CertificateAuthority } -// Handler is the ACME request handler. -type Handler struct { - Auth acme.Interface +// NewHandler returns a new ACME API handler. +func NewHandler(ops HandlerOptions) api.RouterHandler { + return &Handler{ + ca: ops.CA, + db: ops.DB, + backdate: ops.Backdate, + linker: NewLinker(ops.DNS, ops.Prefix), + } } // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { - getLink := h.Auth.GetLinkExplicit + getLink := h.linker.GetLinkExplicit // Standard ACME API - r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) - r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) extractPayloadByJWK := func(next nextHTTP) nextHTTP { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) @@ -59,16 +85,16 @@ func (h *Handler) Route(r api.Router) { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) } - r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) - r.MethodFunc("POST", getLink(acme.KeyChangeLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) - r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) - r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) + r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) + r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) + r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) + r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) + r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) + r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) + r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) + r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) + r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) + r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } // GetNonce just sets the right header since a Nonce is added to each response @@ -81,101 +107,165 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { } } +// Directory represents an ACME directory for configuring clients. +type Directory struct { + NewNonce string `json:"newNonce,omitempty"` + NewAccount string `json:"newAccount,omitempty"` + NewOrder string `json:"newOrder,omitempty"` + NewAuthz string `json:"newAuthz,omitempty"` + RevokeCert string `json:"revokeCert,omitempty"` + KeyChange string `json:"keyChange,omitempty"` +} + +// ToLog enables response logging for the Directory type. +func (d *Directory) ToLog() (interface{}, error) { + b, err := json.Marshal(d) + if err != nil { + return nil, acme.WrapErrorISE(err, "error marshaling directory for logging") + } + return string(b), nil +} + +type directory struct { + prefix, dns string +} + // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { - dir, err := h.Auth.GetDirectory(r.Context()) - if err != nil { - api.WriteError(w, err) - return - } - api.JSON(w, dir) + ctx := r.Context() + api.JSON(w, &Directory{ + NewNonce: h.linker.GetLink(ctx, NewNonceLinkType, true), + NewAccount: h.linker.GetLink(ctx, NewAccountLinkType, true), + NewOrder: h.linker.GetLink(ctx, NewOrderLinkType, true), + RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType, true), + KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType, true), + }) } // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NotImplemented(nil).ToACME()) + api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthz ACME api for retrieving an Authz. func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID")) + az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + return + } + if acc.ID != az.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } + if err = az.UpdateStatus(ctx, h.db); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + } + + h.linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID())) - api.JSON(w, authz) + w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, true, az.ID)) + api.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. - _, err = payloadFromContext(r.Context()) + _, err = payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return } - // NOTE: We should be checking that the request is either a POST-as-GET, or + // NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or // that the payload is an empty JSON block ({}). However, older ACME clients // still send a vestigial body (rather than an empty JSON block) and // strict enforcement would render these clients broken. For the time being // we'll just ignore the body. - var ( - ch *acme.Challenge - chID = chi.URLParam(r, "chID") - ) - ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey()) + + ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), chi.URLParam(r, "authzID")) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + return + } + if acc.ID != ch.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) + return + } + client := http.Client{ + Timeout: time.Duration(30 * time.Second), + } + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + } + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } + if err = ch.Validate(ctx, h.db, jwk, acme.ValidateOptions{ + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, config) + }, + }); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + return + } - w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up")) - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID())) + h.linker.LinkChallenge(ctx, ch) + + w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, ch.AuthzID), "up")) + w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID)) api.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } certID := chi.URLParam(r, "certID") - certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID) + + cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) return } - - block, _ := pem.Decode(certBytes) - if block == nil { - api.WriteError(w, acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes"))) + if cert.AccountID != acc.ID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own certificate '%s'", acc.ID, certID)) return } - cert, err := x509.ParseCertificate(block.Bytes) + + certBytes, err := cert.ToACME() if err != nil { - api.WriteError(w, acme.Wrap(err, "failed to parse generated leaf certificate")) + api.WriteError(w, acme.WrapErrorISE(err, "error converting cert to ACME representation")) return } - api.LogCertificate(w, cert) + api.LogCertificate(w, cert.Leaf) w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8") w.Write(certBytes) } diff --git a/acme/api/linker.go b/acme/api/linker.go new file mode 100644 index 00000000..dd3b4540 --- /dev/null +++ b/acme/api/linker.go @@ -0,0 +1,164 @@ +package api + +import ( + "context" + "fmt" + "net/url" + + "github.com/smallstep/certificates/acme" +) + +// NewLinker returns a new Directory type. +func NewLinker(dns, prefix string) *Linker { + return &Linker{Prefix: prefix, DNS: dns} +} + +// Linker generates ACME links. +type Linker struct { + Prefix string + DNS string +} + +// GetLink is a helper for GetLinkExplicit +func (l *Linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { + var provName string + if p, err := provisionerFromContext(ctx); err == nil && p != nil { + provName = p.GetName() + } + return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...) +} + +// GetLinkExplicit returns an absolute or partial path to the given resource and a base +// URL dynamically obtained from the request for which the link is being +// calculated. +func (l *Linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { + var link string + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + link = fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + link = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + } + + if abs { + // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 + u := url.URL{} + if baseURL != nil { + u = *baseURL + } + + // If no Scheme is set, then default to https. + if u.Scheme == "" { + u.Scheme = "https" + } + + // If no Host is set, then use the default (first DNS attr in the ca.json). + if u.Host == "" { + u.Host = l.DNS + } + + u.Path = l.Prefix + link + return u.String() + } + return link +} + +// LinkType captures the link type. +type LinkType int + +const ( + // NewNonceLinkType new-nonce + NewNonceLinkType LinkType = iota + // NewAccountLinkType new-account + NewAccountLinkType + // AccountLinkType account + AccountLinkType + // OrderLinkType order + OrderLinkType + // NewOrderLinkType new-order + NewOrderLinkType + // OrdersByAccountLinkType list of orders owned by account + OrdersByAccountLinkType + // FinalizeLinkType finalize order + FinalizeLinkType + // NewAuthzLinkType authz + NewAuthzLinkType + // AuthzLinkType new-authz + AuthzLinkType + // ChallengeLinkType challenge + ChallengeLinkType + // CertificateLinkType certificate + CertificateLinkType + // DirectoryLinkType directory + DirectoryLinkType + // RevokeCertLinkType revoke certificate + RevokeCertLinkType + // KeyChangeLinkType key rollover + KeyChangeLinkType +) + +func (l LinkType) String() string { + switch l { + case NewNonceLinkType: + return "new-nonce" + case NewAccountLinkType: + return "new-account" + case AccountLinkType: + return "account" + case NewOrderLinkType: + return "new-order" + case OrderLinkType: + return "order" + case NewAuthzLinkType: + return "new-authz" + case AuthzLinkType: + return "authz" + case ChallengeLinkType: + return "challenge" + case CertificateLinkType: + return "certificate" + case DirectoryLinkType: + return "directory" + case RevokeCertLinkType: + return "revoke-cert" + case KeyChangeLinkType: + return "key-change" + default: + return fmt.Sprintf("unexpected LinkType '%d'", int(l)) + } +} + +// LinkOrder sets the ACME links required by an ACME order. +func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) { + o.azURLs = make([]string, len(o.AuthorizationIDs)) + for i, azID := range o.AutohrizationIDs { + o.azURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) + } + o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID) + if o.CertificateID != "" { + o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID) + } +} + +// LinkAccount sets the ACME links required by an ACME account. +func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) { + a.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) +} + +// LinkChallenge sets the ACME links required by an ACME account. +func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { + a.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) +} + +// LinkAuthorization sets the ACME links required by an ACME account. +func (l *Linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { + for _, ch := range az.Challenges { + l.LinkChallenge(ctx, ch) + } +} diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 3bf5f89a..7a3529cd 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" @@ -54,7 +53,7 @@ func baseURLFromRequest(r *http.Request) *url.URL { // E.g. https://ca.smallstep.com/ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r)) + ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) next(w, r.WithContext(ctx)) } } @@ -62,14 +61,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { // addNonce is a middleware that adds a nonce to the response header. func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.Auth.NewNonce() + nonce, err := h.db.CreateNonce(r.Context()) if err != nil { api.WriteError(w, err) return } - w.Header().Set("Replay-Nonce", nonce) + w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Cache-Control", "no-store") - logNonce(w, nonce) + logNonce(w, string(nonce)) next(w, r) } } @@ -78,8 +77,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // directory index url. func (h *Handler) addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), - acme.DirectoryLink, true), "index")) + w.Header().Add("Link", link(h.linker.GetLink(r.Context(), + DirectoryLinkType, true), "index")) next(w, r) } } @@ -90,7 +89,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ct := r.Header.Get("Content-Type") var expected []string - if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) { + if strings.Contains(r.URL.Path, h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { @@ -103,8 +102,8 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.MalformedErr(errors.Errorf( - "expected content-type to be in %s, but got %s", expected, ct))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -113,15 +112,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body"))) + api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } - ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws) + ctx := context.WithValue(r.Context(), jwsContextKey, jws) next(w, r.WithContext(ctx)) } } @@ -143,17 +142,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below func (h *Handler) validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + ctx := r.Context() + jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -164,7 +164,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -174,25 +174,26 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+ - "keys must be at least %d bits (%d bytes) in size", - 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "rsa keys must be at least %d bits (%d bytes) in size", + 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "jws key type and algorithm do not match")) return } } case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. - if err := h.Auth.UseNonce(hdr.Nonce); err != nil { + if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { api.WriteError(w, err) return } @@ -200,21 +201,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -227,22 +229,27 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := acme.JwsFromContext(r.Context()) + jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + return + } + ctx = context.WithValue(ctx, jwkContextKey, jwk) + kid, err := acme.KeyToID(jwk) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - acc, err := h.Auth.GetAccountByKey(ctx, jwk) + acc, err := h.db.GetAccountByKeyID(ctx, kid) switch { case nosql.IsErrNotFound(err): // For NewAccount requests ... @@ -252,10 +259,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return default: if !acc.IsValid() { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, accContextKey, acc) } next(w, r.WithContext(ctx)) } @@ -270,20 +277,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { name := chi.URLParam(r, "provisionerID") provID, err := url.PathUnescape(name) if err != nil { - api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name))) + api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name)) return } - p, err := h.Auth.LoadProvisionerByID("acme/" + provID) + p, err := h.ca.LoadProvisionerByID("acme/" + provID) if err != nil { api.WriteError(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } - ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv)) + ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) next(w, r.WithContext(ctx)) } } @@ -294,36 +301,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := acme.JwsFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "") + kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ - "required prefix; expected %s, but got %s", kidPrefix, kid))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "kid does not have required prefix; expected %s, but got %s", + kidPrefix, kid)) return } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.Auth.GetAccount(r.Context(), accID) + acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: api.WriteError(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, jwkContextKey, acc.Key) next(w, r.WithContext(ctx)) return } @@ -334,26 +342,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + ctx := r.Context() + jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - jwk, err := acme.JwkFromContext(r.Context()) + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } - ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{ + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ value: payload, isPostAsGet: string(payload) == "", isEmptyJSON: string(payload) == "{}", @@ -371,9 +380,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return } if !payload.isPostAsGet { - api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) } } + +// ContextKey is the key type for storing and searching for ACME request +// essentials in the context of a request. +type ContextKey string + +const ( + // accContextKey account key + accContextKey = ContextKey("acc") + // baseURLContextKey baseURL key + baseURLContextKey = ContextKey("baseURL") + // jwsContextKey jws key + jwsContextKey = ContextKey("jws") + // jwkContextKey jwk key + jwkContextKey = ContextKey("jwk") + // payloadContextKey payload key + payloadContextKey = ContextKey("payload") + // provisionerContextKey provisioner key + provisionerContextKey = ContextKey("provisioner") +) + +// accountFromContext searches the context for an ACME account. Returns the +// account or an error. +func accountFromContext(ctx context.Context) (*acme.Account, error) { + val, ok := ctx.Value(accContextKey).(*acme.Account) + if !ok || val == nil { + return nil, acme.NewErrorISE("account not in context") + } + return val, nil +} + +// baseURLFromContext returns the baseURL if one is stored in the context. +func baseURLFromContext(ctx context.Context) *url.URL { + val, ok := ctx.Value(baseURLContextKey).(*url.URL) + if !ok || val == nil { + return nil + } + return val +} + +// jwkFromContext searches the context for a JWK. Returns the JWK or an error. +func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { + val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) + if !ok || val == nil { + return nil, acme.NewErrorISE("jwk expected in request context") + } + return val, nil +} + +// jwsFromContext searches the context for a JWS. Returns the JWS or an error. +func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { + val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature) + if !ok || val == nil { + return nil, acme.NewErrorISE("jws expected in request context") + } + return val, nil +} + +// provisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { + val := ctx.Value(provisionerContextKey) + if val == nil { + return nil, acme.NewErrorISE("provisioner expected in request context") + } + pval, ok := val.(acme.Provisioner) + if !ok || pval == nil { + return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") + } + return pval, nil +} + +// payloadFromContext searches the context for a payload. Returns the payload +// or an error. +func payloadFromContext(ctx context.Context) (*payloadInfo, error) { + val, ok := ctx.Value(payloadContextKey).(*payloadInfo) + if !ok || val == nil { + return nil, acme.NewErrorISE("payload expected in request context") + } + return val, nil +} diff --git a/acme/api/order.go b/acme/api/order.go index 1fead85c..2bf7d2ef 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -1,16 +1,18 @@ package api import ( + "context" "crypto/x509" "encoding/base64" "encoding/json" "net/http" + "strings" "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "go.step.sm/crypto/randutil" ) // NewOrderRequest represents the body for a NewOrder request. @@ -23,11 +25,11 @@ type NewOrderRequest struct { // Validate validates a new-order request body. func (n *NewOrderRequest) Validate() error { if len(n.Identifiers) == 0 { - return acme.NewError(ErrorMalformedType, "identifiers list cannot be empty") + return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { if id.Type != "dns" { - return acme.NewError(ErrorMalformedType, "identifier type unsupported: %s", id.Type) + return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } } return nil @@ -44,22 +46,29 @@ func (f *FinalizeRequest) Validate() error { var err error csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) if err != nil { - return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr")) + return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr") } f.csr, err = x509.ParseCertificateRequest(csrBytes) if err != nil { - return acme.MalformedErr(errors.Wrap(err, "unable to parse csr")) + return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr") } if err = f.csr.CheckSignature(); err != nil { - return acme.MalformedErr(errors.Wrap(err, "csr failed signature check")) + return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check") } return nil } +var defaultOrderExpiry = time.Hour * 24 + // NewOrder ACME api for creating a new order. func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -71,8 +80,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, - "failed to unmarshal new-order request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { @@ -80,44 +89,133 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{ - AccountID: acc.GetID(), - Identifiers: nor.Identifiers, - NotBefore: nor.NotBefore, - NotAfter: nor.NotAfter, - }) - if err != nil { - api.WriteError(w, err) + // New order. + o := &acme.Order{Identifiers: nor.Identifiers} + + o.AuthorizationIDs = make([]string, len(o.Identifiers)) + for i, identifier := range o.Identifiers { + az := &acme.Authorization{ + AccountID: acc.ID, + Identifier: identifier, + } + if err := h.newAuthorization(ctx, az); err != nil { + api.WriteError(w, err) + return + } + o.AuthorizationIDs[i] = az.ID + } + + now := clock.Now() + if o.NotBefore.IsZero() { + o.NotBefore = now + } + if o.NotAfter.IsZero() { + o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration()) + } + o.Expires = now.Add(defaultOrderExpiry) + + if err := h.db.CreateOrder(ctx, o); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) return } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) + h.linker.Link(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSONStatus(w, o, http.StatusCreated) } +func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { + if strings.HasPrefix(az.Identifier.Value, "*.") { + az.Wildcard = true + az.Identifier = acme.Identifier{ + Value: strings.TrimPrefix(az.Identifier.Value, "*."), + Type: az.Identifier.Type, + } + } + + var ( + err error + chTypes = []string{"dns-01"} + ) + // HTTP and TLS challenges can only be used for identifiers without wildcards. + if !az.Wildcard { + chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) + } + + az.Token, err = randutil.Alphanumeric(32) + if err != nil { + return acme.WrapErrorISE(err, "error generating random alphanumeric ID") + } + + az.Challenges = make([]*acme.Challenge, len(chTypes)) + for i, typ := range chTypes { + ch := &acme.Challenge{ + AccountID: az.AccountID, + AuthzID: az.ID, + Value: az.Identifier.Value, + Type: typ, + Token: az.Token, + } + if err := h.db.CreateChallenge(ctx, ch); err != nil { + return err + } + az.Challenges[i] = ch + } + if err = h.db.CreateAuthorization(ctx, az); err != nil { + return acme.WrapErrorISE(err, "error creating authorization") + } + return nil +} + // GetOrder ACME api for retrieving an order. func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - oid := chi.URLParam(r, "ordID") - o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid) + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return } + o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + return + } + if acc.ID != o.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own order '%s'", acc.ID, o.ID)) + return + } + if prov.GetID() != o.ProvisionerID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) + return + } + if err = o.UpdateStatus(ctx, h.db); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + return + } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) + h.linker.LinkOrder(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -129,7 +227,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { @@ -137,13 +236,28 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - oid := chi.URLParam(r, "ordID") - o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr) + o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) return } + if acc.ID != o.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own order '%s'", acc.ID, o.ID)) + return + } + if prov.GetID() != o.ProvisionerID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) + return + } + if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) + return + } + + h.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSON(w, o) } diff --git a/acme/authority.go b/acme/authority.go deleted file mode 100644 index 92e1c8f7..00000000 --- a/acme/authority.go +++ /dev/null @@ -1,420 +0,0 @@ -package acme - -import ( - "context" - "crypto/tls" - "crypto/x509" - "log" - "net" - "net/http" - "net/url" - "strings" - "time" - - "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/randutil" -) - -// Interface is the acme authority interface. -type Interface interface { - GetDirectory(ctx context.Context) (*Directory, error) - NewNonce() (string, error) - UseNonce(string) error - - DeactivateAccount(ctx context.Context, accID string) (*Account, error) - GetAccount(ctx context.Context, accID string) (*Account, error) - GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error) - NewAccount(ctx context.Context, acc *Account) (*Account, error) - UpdateAccount(ctx context.Context, acc *Account) (*Account, error) - - GetAuthz(ctx context.Context, accID string, authzID string) (*Authorization, error) - ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error) - - FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error) - GetOrder(ctx context.Context, accID string, orderID string) (*Order, error) - GetOrdersByAccount(ctx context.Context, accID string) ([]string, error) - NewOrder(ctx context.Context, o *Order) (*Order, error) - - GetCertificate(string, string) ([]byte, error) - - LoadProvisionerByID(string) (provisioner.Interface, error) - GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string - GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string -} - -// Authority is the layer that handles all ACME interactions. -type Authority struct { - backdate provisioner.Duration - db DB - dir *directory - signAuth SignAuthority -} - -// AuthorityOptions required to create a new ACME Authority. -type AuthorityOptions struct { - Backdate provisioner.Duration - // DB storage backend that impements the acme.DB interface. - DB DB - // DNS the host used to generate accurate ACME links. By default the authority - // will use the Host from the request, so this value will only be used if - // request.Host is empty. - DNS string - // Prefix is a URL path prefix under which the ACME api is served. This - // prefix is required to generate accurate ACME links. - // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- - // "acme" is the prefix from which the ACME api is accessed. - Prefix string -} - -// NewAuthority returns a new Authority that implements the ACME interface. -// -// Deprecated: NewAuthority exists for hitorical compatibility and should not -// be used. Use acme.New() instead. -func NewAuthority(db DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { - return New(signAuth, AuthorityOptions{ - DB: db, - DNS: dns, - Prefix: prefix, - }) -} - -// New returns a new Authority that implements the ACME interface. -func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { - return &Authority{ - backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth, - }, nil -} - -// GetLink returns the requested link from the directory. -func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { - return a.dir.getLink(ctx, typ, abs, inputs...) -} - -// GetLinkExplicit returns the requested link from the directory. -func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string { - return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...) -} - -// GetDirectory returns the ACME directory object. -func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) { - return &Directory{ - NewNonce: a.dir.getLink(ctx, NewNonceLink, true), - NewAccount: a.dir.getLink(ctx, NewAccountLink, true), - NewOrder: a.dir.getLink(ctx, NewOrderLink, true), - RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true), - KeyChange: a.dir.getLink(ctx, KeyChangeLink, true), - }, nil -} - -// LoadProvisionerByID calls out to the SignAuthority interface to load a -// provisioner by ID. -func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { - return a.signAuth.LoadProvisionerByID(id) -} - -// NewNonce generates, stores, and returns a new ACME nonce. -func (a *Authority) NewNonce(ctx context.Context) (Nonce, error) { - return a.db.CreateNonce(ctx) -} - -// UseNonce consumes the given nonce if it is valid, returns error otherwise. -func (a *Authority) UseNonce(ctx context.Context, nonce string) error { - return a.db.DeleteNonce(ctx, Nonce(nonce)) -} - -// NewAccount creates, stores, and returns a new ACME account. -func (a *Authority) NewAccount(ctx context.Context, acc *Account) error { - if err := a.db.CreateAccount(ctx, acc); err != nil { - return ErrorISEWrap(err, "error creating account") - } - return nil -} - -// UpdateAccount updates an ACME account. -func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, error) { - /* - acc.Contact = auo.Contact - acc.Status = auo.Status - */ - if err := a.db.UpdateAccount(ctx, acc); err != nil { - return nil, ErrorISEWrap(err, "error updating account") - } - return acc, nil -} - -// GetAccount returns an ACME account. -func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { - acc, err := a.db.GetAccount(ctx, id) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving account") - } - return acc, nil -} - -// GetAccountByKey returns the ACME associated with the jwk id. -func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) { - kid, err := KeyToID(jwk) - if err != nil { - return nil, err - } - acc, err := a.db.GetAccountByKeyID(ctx, kid) - return acc, err -} - -// GetOrder returns an ACME order. -func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err - } - o, err := a.db.GetOrder(ctx, orderID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving order") - } - if accID != o.AccountID { - log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - if prov.GetID() != o.ProvisionerID { - log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) - return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") - } - if err = o.UpdateStatus(ctx, a.db); err != nil { - return nil, ErrorISEWrap(err, "error updating order") - } - return o, nil -} - -/* -// GetOrdersByAccount returns the list of order urls owned by the account. -func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - ordersByAccountMux.Lock() - defer ordersByAccountMux.Unlock() - - var oiba = orderIDsByAccount{} - oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id) - if err != nil { - return nil, err - } - - var ret = []string{} - for _, oid := range oids { - ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid)) - } - return ret, nil -} -*/ - -// NewOrder generates, stores, and returns a new ACME order. -func (a *Authority) NewOrder(ctx context.Context, o *Order) error { - if len(o.AccountID) == 0 { - return NewErrorISE("account-id cannot be empty") - } - if len(o.ProvisionerID) == 0 { - return NewErrorISE("provisioner-id cannot be empty") - } - if len(o.Identifiers) == 0 { - return NewErrorISE("identifiers cannot be empty") - } - if o.DefaultDuration == 0 { - return NewErrorISE("default-duration cannot be empty") - } - - o.AuthorizationIDs = make([]string, len(o.Identifiers)) - for i, identifier := range o.Identifiers { - az := &Authorization{ - AccountID: o.AccountID, - Identifier: identifier, - } - if err := a.NewAuthorization(ctx, az); err != nil { - return err - } - o.AuthorizationIDs[i] = az.ID - } - - now := clock.Now() - if o.NotBefore.IsZero() { - o.NotBefore = now - } - if o.NotAfter.IsZero() { - o.NotAfter = o.NotBefore.Add(o.DefaultDuration) - } - - if err := a.db.CreateOrder(ctx, o); err != nil { - return ErrorISEWrap(err, "error creating order") - } - return nil - /* - o.DefaultDuration = prov.DefaultTLSCertDuration() - o.Backdate = a.backdate.Duration - o.ProvisionerID = prov.GetID() - - if err = a.db.CreateOrder(ctx, o); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error creating order") - } - return o, nil - */ -} - -// FinalizeOrder attempts to finalize an order and generate a new certificate. -func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err - } - o, err := a.db.GetOrder(ctx, orderID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving order") - } - if accID != o.AccountID { - log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - if prov.GetID() != o.ProvisionerID { - log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) - return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") - } - if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil { - return nil, ErrorISEWrap(err, "error finalizing order") - } - return o, nil -} - -// NewAuthorization generates and stores an ACME Authorization type along with -// any associated resources. -func (a *Authority) NewAuthorization(ctx context.Context, az *Authorization) error { - if len(az.AccountID) == 0 { - return NewErrorISE("account-id cannot be empty") - } - if len(az.Identifier.Value) == 0 { - return NewErrorISE("identifier cannot be empty") - } - - if strings.HasPrefix(az.Identifier.Value, "*.") { - az.Wildcard = true - az.Identifier = Identifier{ - Value: strings.TrimPrefix(az.Identifier.Value, "*."), - Type: az.Identifier.Type, - } - } - - var ( - err error - chTypes = []string{"dns-01"} - ) - // HTTP and TLS challenges can only be used for identifiers without wildcards. - if !az.Wildcard { - chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) - } - - az.Token, err = randutil.Alphanumeric(32) - if err != nil { - return ErrorISEWrap(err, "error generating random alphanumeric ID") - } - - az.Challenges = make([]*Challenge, len(chTypes)) - for i, typ := range chTypes { - ch := &Challenge{ - AccountID: az.AccountID, - AuthzID: az.ID, - Value: az.Identifier.Value, - Type: typ, - Token: az.Token, - } - if err := a.NewChallenge(ctx, ch); err != nil { - return err - } - az.Challenges[i] = ch - } - if err = a.db.CreateAuthorization(ctx, az); err != nil { - return ErrorISEWrap(err, "error creating authorization") - } - return nil -} - -// GetAuthorization retrieves and attempts to update the status on an ACME authz -// before returning. -func (a *Authority) GetAuthorization(ctx context.Context, accID, authzID string) (*Authorization, error) { - az, err := a.db.GetAuthorization(ctx, authzID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving authorization") - } - if accID != az.AccountID { - log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - if err = az.UpdateStatus(ctx, a.db); err != nil { - return nil, ErrorISEWrap(err, "error updating authorization status") - } - return az, nil -} - -// NewChallenge generates and stores an ACME challenge and associated resources. -func (a *Authority) NewChallenge(ctx context.Context, ch *Challenge) error { - if len(ch.AccountID) == 0 { - return NewErrorISE("account-id cannot be empty") - } - if len(ch.AuthzID) == 0 { - return NewErrorISE("authz-id cannot be empty") - } - if len(ch.Token) == 0 { - return NewErrorISE("token cannot be empty") - } - if len(ch.Value) == 0 { - return NewErrorISE("value cannot be empty") - } - - switch ch.Type { - case "dns-01", "http-01", "tls-alpn-01": - break - default: - return NewErrorISE("unexpected error type '%s'", ch.Type) - } - - if err := a.db.CreateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "error creating challenge") - } - return nil -} - -// GetValidateChallenge attempts to validate the challenge. -func (a *Authority) GetValidateChallenge(ctx context.Context, accID, chID, azID string, jwk *jose.JSONWebKey) (*Challenge, error) { - ch, err := a.db.GetChallenge(ctx, chID, "todo") - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving challenge") - } - if accID != ch.AccountID { - log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, ch.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - client := http.Client{ - Timeout: time.Duration(30 * time.Second), - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - if err = ch.Validate(ctx, a.db, jwk, validateOptions{ - httpGet: client.Get, - lookupTxt: net.LookupTXT, - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }); err != nil { - return nil, ErrorISEWrap(err, "error validating challenge") - } - return ch, nil -} - -// GetCertificate retrieves the Certificate by ID. -func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { - cert, err := a.db.GetCertificate(ctx, certID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving certificate") - } - if cert.AccountID != accID { - log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - return cert.ToACME(ctx) -} diff --git a/acme/authorization.go b/acme/authorization.go index 7f15f4c6..df4ac229 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -22,7 +22,7 @@ type Authorization struct { func (az *Authorization) ToLog() (interface{}, error) { b, err := json.Marshal(az) if err != nil { - return nil, ErrorISEWrap(err, "error marshaling authz for logging") + return nil, WrapErrorISE(err, "error marshaling authz for logging") } return string(b), nil } @@ -30,11 +30,7 @@ func (az *Authorization) ToLog() (interface{}, error) { // UpdateStatus updates the ACME Authorization Status if necessary. // Changes to the Authorization are saved using the database interface. func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { - now := time.Now().UTC() - expiry, err := time.Parse(time.RFC3339, az.Expires) - if err != nil { - return ErrorISEWrap(err, "error converting expiry string to time") - } + now := clock.Now() switch az.Status { case StatusInvalid: @@ -43,7 +39,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusPending: // check expiry - if now.After(expiry) { + if now.After(az.Expires) { az.Status = StatusInvalid break } @@ -61,11 +57,11 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { } az.Status = StatusValid default: - return NewError(ErrorServerInternalType, "unrecognized authorization status: %s", az.Status) + return NewErrorISE("unrecognized authorization status: %s", az.Status) } - if err = db.UpdateAuthorization(ctx, az); err != nil { - return ErrorISEWrap(err, "error updating authorization") + if err := db.UpdateAuthorization(ctx, az); err != nil { + return WrapErrorISE(err, "error updating authorization") } return nil } diff --git a/acme/certificate.go b/acme/certificate.go index 356c0121..daf9556b 100644 --- a/acme/certificate.go +++ b/acme/certificate.go @@ -1,7 +1,6 @@ package acme import ( - "context" "crypto/x509" "encoding/pem" ) @@ -16,7 +15,7 @@ type Certificate struct { } // ToACME encodes the entire X509 chain into a PEM list. -func (cert *Certificate) ToACME(ctx context.Context) ([]byte, error) { +func (cert *Certificate) ToACME() ([]byte, error) { var ret []byte for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { ret = append(ret, pem.EncodeToMemory(&pem.Block{ diff --git a/acme/challenge.go b/acme/challenge.go index ca2e5562..2abc808c 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -38,7 +38,7 @@ type Challenge struct { func (ch *Challenge) ToLog() (interface{}, error) { b, err := json.Marshal(ch) if err != nil { - return nil, ErrorISEWrap(err, "error marshaling challenge for logging") + return nil, WrapErrorISE(err, "error marshaling challenge for logging") } return string(b), nil } @@ -47,7 +47,7 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { // If already valid or invalid then return without performing validation. if ch.Status == StatusValid || ch.Status == StatusInvalid { return nil @@ -60,16 +60,16 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, case "tls-alpn-01": return tlsalpn01Validate(ctx, ch, db, jwk, vo) default: - return NewError(ErrorServerInternalType, "unexpected challenge type '%s'", ch.Type) + return NewErrorISE("unexpected challenge type '%s'", ch.Type) } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token) - resp, err := vo.httpGet(url) + resp, err := vo.HTTPGet(url) if err != nil { - return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, + return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", url)) } if resp.StatusCode >= 400 { @@ -80,7 +80,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb body, err := ioutil.ReadAll(resp.Body) if err != nil { - return ErrorISEWrap(err, "error reading "+ + return WrapErrorISE(err, "error reading "+ "response body for url %s", url) } keyAuth := strings.Trim(string(body), "\r\n") @@ -100,12 +100,12 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "error updating challenge") + return WrapErrorISE(err, "error updating challenge") } return nil } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, ServerName: ch.Value, @@ -114,9 +114,9 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.tlsDial("tcp", hostPort, config) + conn, err := vo.TLSDial("tcp", hostPort, config) if err != nil { - return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, + return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err, "error doing TLS dial for %s", hostPort)) } defer conn.Close() @@ -178,7 +178,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "tlsalpn01ValidateChallenge - error updating challenge") + return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge") } return nil } @@ -197,16 +197,16 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) + txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) if err != nil { - return storeError(ctx, ch, db, ErrorWrap(ErrorDNSType, err, + return storeError(ctx, ch, db, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) } @@ -234,7 +234,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK ch.Validated = clock.Now().UTC().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "error updating challenge") + return WrapErrorISE(err, "error updating challenge") } return nil } @@ -244,7 +244,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return "", ErrorISEWrap(err, "error generating JWK thumbprint") + return "", WrapErrorISE(err, "error generating JWK thumbprint") } encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) return fmt.Sprintf("%s.%s", token, encPrint), nil @@ -254,7 +254,7 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { ch.Error = err if err := db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "failure saving error to acme challenge") + return WrapErrorISE(err, "failure saving error to acme challenge") } return nil } @@ -263,8 +263,9 @@ type httpGetter func(string) (*http.Response, error) type lookupTxt func(string) ([]string, error) type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) -type validateOptions struct { - httpGet httpGetter - lookupTxt lookupTxt - tlsDial tlsDialer +// ValidateOptions are ACME challenge validator functions. +type ValidateOptions struct { + HTTPGet httpGetter + LookupTxt lookupTxt + TLSDial tlsDialer } diff --git a/acme/common.go b/acme/common.go index b9dc6ff2..f7fd7141 100644 --- a/acme/common.go +++ b/acme/common.go @@ -3,13 +3,27 @@ package acme import ( "context" "crypto/x509" - "net/url" "time" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" ) +// CertificateAuthority is the interface implemented by a CA authority. +type CertificateAuthority interface { + Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + LoadProvisionerByID(string) (provisioner.Interface, error) +} + +// Clock that returns time in UTC rounded to seconds. +type Clock int + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Round(time.Second) +} + +var clock = new(Clock) + // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { @@ -70,89 +84,3 @@ func (m *MockProvisioner) GetID() string { } return m.Mret1.(string) } - -// ContextKey is the key type for storing and searching for ACME request -// essentials in the context of a request. -type ContextKey string - -const ( - // AccContextKey account key - AccContextKey = ContextKey("acc") - // BaseURLContextKey baseURL key - BaseURLContextKey = ContextKey("baseURL") - // JwsContextKey jws key - JwsContextKey = ContextKey("jws") - // JwkContextKey jwk key - JwkContextKey = ContextKey("jwk") - // PayloadContextKey payload key - PayloadContextKey = ContextKey("payload") - // ProvisionerContextKey provisioner key - ProvisionerContextKey = ContextKey("provisioner") -) - -// AccountFromContext searches the context for an ACME account. Returns the -// account or an error. -func AccountFromContext(ctx context.Context) (*Account, error) { - val, ok := ctx.Value(AccContextKey).(*Account) - if !ok || val == nil { - return nil, NewError(ErrorServerInternalType, "account not in context") - } - return val, nil -} - -// BaseURLFromContext returns the baseURL if one is stored in the context. -func BaseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(BaseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - -// JwkFromContext searches the context for a JWK. Returns the JWK or an error. -func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { - val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey) - if !ok || val == nil { - return nil, NewError(ErrorServerInternalType, "jwk expected in request context") - } - return val, nil -} - -// JwsFromContext searches the context for a JWS. Returns the JWS or an error. -func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { - val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature) - if !ok || val == nil { - return nil, NewError(ErrorServerInternalType, "jws expected in request context") - } - return val, nil -} - -// ProvisionerFromContext searches the context for a provisioner. Returns the -// provisioner or an error. -func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { - val := ctx.Value(ProvisionerContextKey) - if val == nil { - return nil, NewError(ErrorServerInternalType, "provisioner expected in request context") - } - pval, ok := val.(Provisioner) - if !ok || pval == nil { - return nil, NewError(ErrorServerInternalType, "provisioner in context is not an ACME provisioner") - } - return pval, nil -} - -// SignAuthority is the interface implemented by a CA authority. -type SignAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - LoadProvisionerByID(string) (provisioner.Interface, error) -} - -// Clock that returns time in UTC rounded to seconds. -type Clock int - -// Now returns the UTC time rounded to seconds. -func (c *Clock) Now() time.Time { - return time.Now().UTC().Round(time.Second) -} - -var clock = new(Clock) diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 40961ce3..befeb54d 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -74,7 +74,6 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) return &acme.Account{ Status: dbacc.Status, Contact: dbacc.Contact, - Orders: dir.getLink(ctx, OrdersByAccountLink, true, dbacc.ID), Key: dbacc.Key, ID: dbacc.ID, }, nil diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 818f5c2d..0992509d 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -14,15 +14,15 @@ var defaultExpiryDuration = time.Hour * 24 // dbAuthz is the base authz type that others build from. type dbAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier *acme.Identifier `json:"identifier"` - Status acme.Status `json:"status"` - Expires time.Time `json:"expires"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` - Error *acme.Error `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier acme.Identifier `json:"identifier"` + Status acme.Status `json:"status"` + Expires time.Time `json:"expires"` + Challenges []string `json:"challenges"` + Wildcard bool `json:"wildcard"` + Created time.Time `json:"created"` + Error *acme.Error `json:"error"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -66,7 +66,7 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat Status: dbaz.Status, Challenges: chs, Wildcard: dbaz.Wildcard, - Expires: dbaz.Expires.Format(time.RFC3339), + Expires: dbaz.Expires, ID: dbaz.ID, }, nil } diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index 378b1f7b..48340cf4 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -21,7 +21,7 @@ type dbChallenge struct { Value string `json:"value"` Validated string `json:"validated"` Created time.Time `json:"created"` - Error *AError `json:"error"` + Error *acme.Error `json:"error"` } func (dbc *dbChallenge) clone() *dbChallenge { @@ -79,7 +79,6 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall Type: dbch.Type, Status: dbch.Status, Token: dbch.Token, - URL: dir.getLink(ctx, ChallengeLink, true, dbch.ID), ID: dbch.ID, AuthzID: dbch.AuthzID, Error: dbch.Error, diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index d2146e22..2f5ee11b 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -11,8 +11,6 @@ import ( "github.com/smallstep/nosql" ) -var defaultOrderExpiry = time.Hour * 24 - // Mutex for locking ordersByAccount index operations. var ordersByAccountMux sync.Mutex @@ -26,16 +24,16 @@ type dbOrder struct { Identifiers []acme.Identifier `json:"identifiers"` NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` - Error *Error `json:"error,omitempty"` + Error *acme.Error `json:"error,omitempty"` Authorizations []string `json:"authorizations"` - Certificate string `json:"certificate,omitempty"` + CertificateID string `json:"certificate,omitempty"` } // getDBOrder retrieves and unmarshals an ACME Order type from the database. func (db *DB) getDBOrder(id string) (*dbOrder, error) { b, err := db.db.Get(orderTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "order %s not found", id) + return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading order %s", id) } @@ -49,34 +47,31 @@ func (db *DB) getDBOrder(id string) (*dbOrder, error) { // GetOrder retrieves an ACME Order from the database. func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { dbo, err := db.getDBOrder(id) - - azs := make([]string, len(dbo.Authorizations)) - for i, aid := range dbo.Authorizations { - azs[i] = dir.getLink(ctx, AuthzLink, true, aid) + if err != nil { + return nil, err } + o := &acme.Order{ - Status: dbo.Status, - Expires: dbo.Expires.Format(time.RFC3339), - Identifiers: dbo.Identifiers, - NotBefore: dbo.NotBefore.Format(time.RFC3339), - NotAfter: dbo.NotAfter.Format(time.RFC3339), - Authorizations: azs, - FinalizeURL: dir.getLink(ctx, FinalizeLink, true, o.ID), - ID: dbo.ID, - ProvisionerID: dbo.ProvisionerID, + Status: dbo.Status, + Expires: dbo.Expires, + Identifiers: dbo.Identifiers, + NotBefore: dbo.NotBefore, + NotAfter: dbo.NotAfter, + AuthorizationIDs: dbo.Authorizations, + ID: dbo.ID, + ProvisionerID: dbo.ProvisionerID, + CertificateID: dbo.CertificateID, } - if dbo.Certificate != "" { - o.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate) - } return o, nil } // CreateOrder creates ACME Order resources and saves them to the DB. func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { + var err error o.ID, err = randID() if err != nil { - return nil, err + return err } now := clock.Now() @@ -85,23 +80,23 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { AccountID: o.AccountID, ProvisionerID: o.ProvisionerID, Created: now, - Status: StatusPending, - Expires: now.Add(defaultOrderExpiry), + Status: acme.StatusPending, + Expires: o.Expires, Identifiers: o.Identifiers, NotBefore: o.NotBefore, NotAfter: o.NotBefore, Authorizations: o.AuthorizationIDs, } - if err := db.save(ctx, o.ID, dbo, nil, orderTable); err != nil { - return nil, err + if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { + return err } var oidHelper = orderIDsByAccount{} _, err = oidHelper.addOrderID(db, o.AccountID, o.ID) if err != nil { - return nil, err + return err } - return o, nil + return nil } type orderIDsByAccount struct{} @@ -135,11 +130,11 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri if nosql.IsErrNotFound(err) { return []string{}, nil } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) + return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) } var oids []string if err := json.Unmarshal(b, &oids); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) + return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } // Remove any order that is not in PENDING state and update the stored list @@ -152,21 +147,21 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri for _, oid := range oids { o, err := getOrder(db, oid) if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) + return nil, errors.Wrapf(err, "error loading order %s for account %s", oid, accID) } if o, err = o.UpdateStatus(db); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) + return nil, errors.Wrapf(err, "error updating order %s for account %s", oid, accID) } - if o.Status == StatusPending { + if o.Status == acme.StatusPending { pendOids = append(pendOids, oid) } } // If the number of pending orders is less than the number of orders in the // list, then update the pending order list. if len(pendOids) != len(oids) { - if err = orderIDs(pendOiUs).save(db, oids, accID); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids))) + if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + return nil, errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ + "len(orderIDs) = %d", len(pendOids)) } } @@ -192,7 +187,7 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { } else { oldb, err = json.Marshal(old) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) + return errors.Wrap(err, "error marshaling old order IDs slice") } } if len(oids) == 0 { @@ -200,13 +195,13 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { } else { newb, err = json.Marshal(oids) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) + return errors.Wrap(err, "error marshaling new order IDs slice") } } _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) switch { case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) + return errors.Wrapf(err, "error storing order IDs for account %s", accID) case !swapped: return ServerInternalErr(errors.Errorf("error storing order IDs "+ "for account %s; order IDs changed since last read", accID)) diff --git a/acme/directory.go b/acme/directory.go deleted file mode 100644 index 8520d0e9..00000000 --- a/acme/directory.go +++ /dev/null @@ -1,148 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "fmt" - "net/url" -) - -// Directory represents an ACME directory for configuring clients. -type Directory struct { - NewNonce string `json:"newNonce,omitempty"` - NewAccount string `json:"newAccount,omitempty"` - NewOrder string `json:"newOrder,omitempty"` - NewAuthz string `json:"newAuthz,omitempty"` - RevokeCert string `json:"revokeCert,omitempty"` - KeyChange string `json:"keyChange,omitempty"` -} - -// ToLog enables response logging for the Directory type. -func (d *Directory) ToLog() (interface{}, error) { - b, err := json.Marshal(d) - if err != nil { - return nil, ErrorISEWrap(err, "error marshaling directory for logging") - } - return string(b), nil -} - -type directory struct { - prefix, dns string -} - -// newDirectory returns a new Directory type. -func newDirectory(dns, prefix string) *directory { - return &directory{prefix: prefix, dns: dns} -} - -// Link captures the link type. -type Link int - -const ( - // NewNonceLink new-nonce - NewNonceLink Link = iota - // NewAccountLink new-account - NewAccountLink - // AccountLink account - AccountLink - // OrderLink order - OrderLink - // NewOrderLink new-order - NewOrderLink - // OrdersByAccountLink list of orders owned by account - OrdersByAccountLink - // FinalizeLink finalize order - FinalizeLink - // NewAuthzLink authz - NewAuthzLink - // AuthzLink new-authz - AuthzLink - // ChallengeLink challenge - ChallengeLink - // CertificateLink certificate - CertificateLink - // DirectoryLink directory - DirectoryLink - // RevokeCertLink revoke certificate - RevokeCertLink - // KeyChangeLink key rollover - KeyChangeLink -) - -func (l Link) String() string { - switch l { - case NewNonceLink: - return "new-nonce" - case NewAccountLink: - return "new-account" - case AccountLink: - return "account" - case NewOrderLink: - return "new-order" - case OrderLink: - return "order" - case NewAuthzLink: - return "new-authz" - case AuthzLink: - return "authz" - case ChallengeLink: - return "challenge" - case CertificateLink: - return "certificate" - case DirectoryLink: - return "directory" - case RevokeCertLink: - return "revoke-cert" - case KeyChangeLink: - return "key-change" - default: - return "unexpected" - } -} - -func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { - var provName string - if p, err := ProvisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...) -} - -// getLinkExplicit returns an absolute or partial path to the given resource and a base -// URL dynamically obtained from the request for which the link is being -// calculated. -func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { - var link string - switch typ { - case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: - link = fmt.Sprintf("/%s/%s", provisionerName, typ.String()) - case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink: - link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0]) - case OrdersByAccountLink: - link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0]) - case FinalizeLink: - link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) - } - - if abs { - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - u := url.URL{} - if baseURL != nil { - u = *baseURL - } - - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } - - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = d.dns - } - - u.Path = d.prefix + link - return u.String() - } - return link -} diff --git a/acme/errors.go b/acme/errors.go index 8fe2559d..41305c87 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -262,7 +262,7 @@ var ( // Error represents an ACME type Error struct { Type string `json:"type"` - Details string `json:"detail"` + Detail string `json:"detail"` Subproblems []interface{} `json:"subproblems,omitempty"` Identifier interface{} `json:"identifier,omitempty"` Err error `json:"-"` @@ -275,18 +275,18 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error { if !ok { meta = errorServerInternalMetadata return &Error{ - Type: meta.typ, - Details: meta.details, - Status: meta.status, - Err: errors.Errorf("unrecognized problemType %v", pt), + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: errors.Errorf("unrecognized problemType %v", pt), } } return &Error{ - Type: meta.typ, - Details: meta.details, - Status: meta.status, - Err: errors.Errorf(msg, args...), + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: errors.Errorf(msg, args...), } } @@ -295,14 +295,14 @@ func NewErrorISE(msg string, args ...interface{}) *Error { return NewError(ErrorServerInternalType, msg, args...) } -// ErrorWrap attempts to wrap the internal error. -func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Error { +// WrapError attempts to wrap the internal error. +func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error { switch e := err.(type) { case nil: return nil case *Error: if e.Err == nil { - e.Err = errors.Errorf(msg+"; "+e.Details, args...) + e.Err = errors.Errorf(msg+"; "+e.Detail, args...) } else { e.Err = errors.Wrapf(e.Err, msg, args...) } @@ -312,9 +312,9 @@ func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Err } } -// ErrorISEWrap shortcut to wrap an internal server error type. -func ErrorISEWrap(err error, msg string, args ...interface{}) *Error { - return ErrorWrap(ErrorServerInternalType, err, msg, args...) +// WrapErrorISE shortcut to wrap an internal server error type. +func WrapErrorISE(err error, msg string, args ...interface{}) *Error { + return WrapError(ErrorServerInternalType, err, msg, args...) } // StatusCode returns the status code and implements the StatusCoder interface. @@ -324,13 +324,13 @@ func (e *Error) StatusCode() int { // Error allows AError to implement the error interface. func (e *Error) Error() string { - return e.Details + return e.Detail } // Cause returns the internal error and implements the Causer interface. func (e *Error) Cause() error { if e.Err == nil { - return errors.New(e.Details) + return errors.New(e.Detail) } return e.Err } diff --git a/acme/nonce.go b/acme/nonce.go index 4234e818..25c86360 100644 --- a/acme/nonce.go +++ b/acme/nonce.go @@ -1,3 +1,9 @@ package acme +// Nonce represents an ACME nonce type. type Nonce string + +// String implements the ToString interface. +func (n Nonce) String() string { + return string(n) +} diff --git a/acme/order.go b/acme/order.go index bf3297f9..1719d899 100644 --- a/acme/order.go +++ b/acme/order.go @@ -26,10 +26,11 @@ type Order struct { NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` Error interface{} `json:"error,omitempty"` - AuthorizationURLs []string `json:"authorizations"` AuthorizationIDs []string `json:"-"` + AuthorizationURLs []string `json:"authorizations"` FinalizeURL string `json:"finalize"` - Certificate string `json:"certificate,omitempty"` + CertificateID string `json:"-"` + CertificateURL string `json:"certificate,omitempty"` ID string `json:"-"` AccountID string `json:"-"` ProvisionerID string `json:"-"` @@ -41,7 +42,7 @@ type Order struct { func (o *Order) ToLog() (interface{}, error) { b, err := json.Marshal(o) if err != nil { - return nil, ErrorISEWrap(err, "error marshaling order for logging") + return nil, WrapErrorISE(err, "error marshaling order for logging") } return string(b), nil } @@ -111,7 +112,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { // Finalize signs a certificate if the necessary conditions for Order completion // have been met. -func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { +func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error { if err := o.UpdateStatus(ctx, db); err != nil { return err } @@ -170,7 +171,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return ErrorISEWrap(err, "error retrieving authorization options from ACME provisioner") + return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner") } // Template data @@ -180,7 +181,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return ErrorISEWrap(err, "error creating template options from ACME provisioner") + return WrapErrorISE(err, "error creating template options from ACME provisioner") } signOps = append(signOps, templateOptions) @@ -190,7 +191,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) if err != nil { - return ErrorISEWrap(err, "error signing certificate for order %s", o.ID) + return WrapErrorISE(err, "error signing certificate for order %s", o.ID) } cert := &Certificate{ @@ -203,7 +204,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques return err } - o.Certificate = cert.ID + o.CertificateID = cert.ID o.Status = StatusValid return db.UpdateOrder(ctx, o) } diff --git a/ca/acmeClient.go b/ca/acmeClient.go index deb8a3a2..b19ad664 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -21,7 +21,7 @@ import ( type ACMEClient struct { client *http.Client dirLoc string - dir *acme.Directory + dir *acmeAPI.Directory acc *acme.Account Key *jose.JSONWebKey kid string @@ -53,7 +53,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } - var dir acme.Directory + var dir acmeAPI.Directory if err := readJSON(resp.Body, &dir); err != nil { return nil, errors.Wrapf(err, "error reading %s", endpoint) } @@ -93,7 +93,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC // GetDirectory makes a directory request to the ACME api and returns an // ACME directory object. -func (c *ACMEClient) GetDirectory() (*acme.Directory, error) { +func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) { return c.dir, nil } @@ -231,7 +231,7 @@ func (c *ACMEClient) ValidateChallenge(url string) error { } // GetAuthz returns the Authz at the given path. -func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { +func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, err @@ -240,7 +240,7 @@ func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { return nil, readACMEError(resp.Body) } - var az acme.Authz + var az acme.Authorization if err := readJSON(resp.Body, &az); err != nil { return nil, errors.Wrapf(err, "error reading %s", url) } @@ -342,7 +342,7 @@ func readACMEError(r io.ReadCloser) error { if err != nil { return errors.Wrap(err, "error reading from body") } - ae := new(acme.AError) + ae := new(acme.Error) err = json.Unmarshal(b, &ae) // If we successfully marshaled to an ACMEError then return the ACMEError. if err != nil || len(ae.Error()) == 0 { diff --git a/ca/ca.go b/ca/ca.go index 5ba81e9e..5ebc0919 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -11,8 +11,8 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" - "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" + acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/db" @@ -124,11 +124,12 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } prefix := "acme" - acmeAuth, err := acme.New(auth, acme.AuthorityOptions{ + acmeAuth, err := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ Backdate: *config.AuthorityConfig.Backdate, - DB: auth.GetDatabase().(nosql.DB), + DB: acmeNoSQL.New(auth.GetDatabase().(nosql.DB)), DNS: dns, Prefix: prefix, + CA: auth, }) if err != nil { return nil, errors.Wrap(err, "error creating ACME authority") From 116869ebc5ca3bbfa07946caf9cbb8fa842703dc Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 4 Mar 2021 23:14:56 -0800 Subject: [PATCH 12/47] [acme db interface] wip --- acme/api/account.go | 17 +++++++++-------- acme/api/linker.go | 10 +++++----- acme/api/order.go | 2 +- acme/db/nosql/order.go | 4 ++-- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index 16cc1f79..6f2a5f96 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -65,7 +65,8 @@ func (u *UpdateAccountRequest) Validate() error { // NewAccount is the handler resource for creating new ACME accounts. func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { - payload, err := payloadFromContext(r.Context()) + ctx := r.Context() + payload, err := payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -97,7 +98,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { "account does not exist")) return } - jwk, err := jwkFromContext(r.Context()) + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -108,7 +109,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Contact: nar.Contact, Status: acme.StatusValid, } - if err := h.db.CreateAccount(r.Context(), acc); err != nil { + if err := h.db.CreateAccount(ctx, acc); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) return } @@ -126,12 +127,13 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { // GetUpdateAccount is the api for updating an ACME account. func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { - acc, err := accountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - payload, err := payloadFromContext(r.Context()) + payload, err := payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -156,7 +158,7 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). acc.Status = uar.Status acc.Contact = uar.Contact - if err = h.db.UpdateAccount(r.Context(), acc); err != nil { + if err = h.db.UpdateAccount(ctx, acc); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) return } @@ -164,8 +166,7 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, - true, acc.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, true, acc.ID)) api.JSON(w, acc) } diff --git a/acme/api/linker.go b/acme/api/linker.go index dd3b4540..d07271e2 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -136,9 +136,9 @@ func (l LinkType) String() string { // LinkOrder sets the ACME links required by an ACME order. func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) { - o.azURLs = make([]string, len(o.AuthorizationIDs)) - for i, azID := range o.AutohrizationIDs { - o.azURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) + o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) + for i, azID := range o.AuthorizationIDs { + o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) } o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID) if o.CertificateID != "" { @@ -148,12 +148,12 @@ func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) { // LinkAccount sets the ACME links required by an ACME account. func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) { - a.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) + acc.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) } // LinkChallenge sets the ACME links required by an ACME account. func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { - a.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) + ch.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME account. diff --git a/acme/api/order.go b/acme/api/order.go index 2bf7d2ef..9fe0eb26 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -119,7 +119,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } - h.linker.Link(ctx, o) + h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSONStatus(w, o, http.StatusCreated) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 2f5ee11b..9e83e7ff 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -203,8 +203,8 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { case err != nil: return errors.Wrapf(err, "error storing order IDs for account %s", accID) case !swapped: - return ServerInternalErr(errors.Errorf("error storing order IDs "+ - "for account %s; order IDs changed since last read", accID)) + return errors.Errorf("error storing order IDs "+ + "for account %s; order IDs changed since last read", accID) default: return nil } From fc395f4d69c612b9f0687a4a0b7e110056fec719 Mon Sep 17 00:00:00 2001 From: max furman Date: Sat, 6 Mar 2021 13:06:43 -0800 Subject: [PATCH 13/47] [acme db interface] compiles! --- acme/api/account.go | 41 ++++++++------- acme/api/linker.go | 11 +++- acme/db.go | 3 +- acme/db/nosql/nonce.go | 2 +- acme/db/nosql/nosql.go | 9 ++-- acme/db/nosql/order.go | 112 ++++++++++++++--------------------------- acme/order.go | 2 +- ca/ca.go | 16 +++--- 8 files changed, 84 insertions(+), 112 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index 6f2a5f96..c06c034a 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "github.com/go-chi/chi" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/logging" @@ -181,24 +182,26 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { // GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { - /* - acc, err := acme.AccountFromContext(r.Context()) - if err != nil { - api.WriteError(w, err) - return - } - accID := chi.URLParam(r, "accID") - if acc.ID != accID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param")) - return - } - orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) - if err != nil { - api.WriteError(w, err) - return - } - api.JSON(w, orders) - logOrdersByAccount(w, orders) - */ + ctx := r.Context() + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + accID := chi.URLParam(r, "accID") + if acc.ID != accID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) + return + } + orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) + if err != nil { + api.WriteError(w, err) + return + } + + h.linker.LinkOrdersByAccountID(ctx, orders) + + api.JSON(w, orders) + logOrdersByAccount(w, orders) return } diff --git a/acme/api/linker.go b/acme/api/linker.go index d07271e2..b9215e06 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -151,14 +151,21 @@ func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) { acc.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) } -// LinkChallenge sets the ACME links required by an ACME account. +// LinkChallenge sets the ACME links required by an ACME challenge. func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { ch.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) } -// LinkAuthorization sets the ACME links required by an ACME account. +// LinkAuthorization sets the ACME links required by an ACME authorization. func (l *Linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { for _, ch := range az.Challenges { l.LinkChallenge(ctx, ch) } } + +// LinkOrdersByAccountID converts each order ID to an ACME link. +func (l *Linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { + for i, id := range orders { + orders[i] = l.GetLink(ctx, OrderLinkType, true, id) + } +} diff --git a/acme/db.go b/acme/db.go index dfbd30ce..a19621c0 100644 --- a/acme/db.go +++ b/acme/db.go @@ -24,8 +24,7 @@ type DB interface { UpdateChallenge(ctx context.Context, ch *Challenge) error CreateOrder(ctx context.Context, o *Order) error - DeleteOrder(ctx context.Context, id string) error GetOrder(ctx context.Context, id string) (*Order, error) - GetOrdersByAccountID(ctx context.Context, accountID string) error + GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) UpdateOrder(ctx context.Context, o *Order) error } diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index 02dcda6c..76f742b2 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -43,7 +43,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { // DeleteNonce verifies that the nonce is valid (by checking if it exists), // and if so, consumes the nonce resource by deleting it from the database. -func (db *DB) DeleteNonce(nonce string) error { +func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error { err := db.db.Update(&database.Tx{ Operations: []*database.TxEntry{ { diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index 0c040a89..bcb118d8 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -44,8 +44,7 @@ func New(db nosqlDB.DB) (*DB, error) { func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { newB, err := json.Marshal(nu) if err != nil { - return errors.Wrapf(err, - "error marshaling new acme %s", typ) + return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) } var oldB []byte if old == nil { @@ -53,8 +52,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface } else { oldB, err = json.Marshal(old) if err != nil { - return errors.Wrapf(err, - "error marshaling old acme %s", typ) + return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old) } } @@ -63,8 +61,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface case err != nil: return errors.Wrapf(err, "error saving acme %s", typ) case !swapped: - return errors.Errorf("error saving acme %s; "+ - "changed since last read", typ) + return errors.Errorf("error saving acme %s; changed since last read", typ) default: return nil } diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 9e83e7ff..59afc41c 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -29,8 +29,13 @@ type dbOrder struct { CertificateID string `json:"certificate,omitempty"` } +func (a *dbOrder) clone() *dbOrder { + b := *a + return &b +} + // getDBOrder retrieves and unmarshals an ACME Order type from the database. -func (db *DB) getDBOrder(id string) (*dbOrder, error) { +func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) { b, err := db.db.Get(orderTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id) @@ -46,7 +51,7 @@ func (db *DB) getDBOrder(id string) (*dbOrder, error) { // GetOrder retrieves an ACME Order from the database. func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { - dbo, err := db.getDBOrder(id) + dbo, err := db.getDBOrder(ctx, id) if err != nil { return nil, err } @@ -91,8 +96,7 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { return err } - var oidHelper = orderIDsByAccount{} - _, err = oidHelper.addOrderID(db, o.AccountID, o.ID) + _, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID) if err != nil { return err } @@ -104,28 +108,11 @@ type orderIDsByAccount struct{} // addOrderID adds an order ID to a users index of in progress order IDs. // This method will also cull any orders that are no longer in the `pending` // state from the index before returning it. -func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { +func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) { ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() - // Update the "order IDs by account ID" index - oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID) - if err != nil { - return nil, err - } - newOids := append(oids, oid) - if err = orderIDs(newOids).save(db, oids, accID); err != nil { - // Delete the entire order if storing the index fails. - db.Del(orderTable, []byte(oid)) - return nil, err - } - return newOids, nil -} - -// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the -// account. -func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { - b, err := db.Get(ordersByAccountIDTable, []byte(accID)) + b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) if err != nil { if nosql.IsErrNotFound(err) { return []string{}, nil @@ -145,67 +132,46 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri // that are invalid in the array of URLs. pendOids := []string{} for _, oid := range oids { - o, err := getOrder(db, oid) + o, err := db.GetOrder(ctx, oid) if err != nil { - return nil, errors.Wrapf(err, "error loading order %s for account %s", oid, accID) + return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID) } - if o, err = o.UpdateStatus(db); err != nil { - return nil, errors.Wrapf(err, "error updating order %s for account %s", oid, accID) + if err = o.UpdateStatus(ctx, db); err != nil { + return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID) } if o.Status == acme.StatusPending { pendOids = append(pendOids, oid) } } - // If the number of pending orders is less than the number of orders in the - // list, then update the pending order list. - if len(pendOids) != len(oids) { - if err = orderIDs(pendOids).save(db, oids, accID); err != nil { - return nil, errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids)) + pendOids = append(pendOids, addOids...) + if len(oids) == 0 { + oids = nil + } + if err = db.save(ctx, accID, pendOids, oids, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { + // Delete all orders that may have been previously stored if orderIDsByAccountID update fails. + for _, oid := range addOids { + db.db.Del(orderTable, []byte(oid)) } + return nil, errors.Wrap(err, "error saving OrderIDsByAccountID index") } - return pendOids, nil } -type orderIDs []string - -// save is used to update the list of orderIDs keyed by ACME account ID -// stored in the database. -// -// This method always converts empty lists to 'nil' when storing to the DB. We -// do this to avoid any confusion between an empty list and a nil value in the -// db. -func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { - var ( - err error - oldb []byte - newb []byte - ) - if len(old) == 0 { - oldb = nil - } else { - oldb, err = json.Marshal(old) - if err != nil { - return errors.Wrap(err, "error marshaling old order IDs slice") - } - } - if len(oids) == 0 { - newb = nil - } else { - newb, err = json.Marshal(oids) - if err != nil { - return errors.Wrap(err, "error marshaling new order IDs slice") - } - } - _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) - switch { - case err != nil: - return errors.Wrapf(err, "error storing order IDs for account %s", accID) - case !swapped: - return errors.Errorf("error storing order IDs "+ - "for account %s; order IDs changed since last read", accID) - default: - return nil +func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { + return db.updateAddOrderIDs(ctx, accID) +} + +// UpdateOrder saves an updated ACME Order to the database. +func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { + old, err := db.getDBOrder(ctx, o.ID) + if err != nil { + return err } + + nu := old.clone() + + nu.Status = o.Status + nu.Error = o.Error + nu.CertificateID = o.CertificateID + return db.save(ctx, old.ID, nu, old, "order", orderTable) } diff --git a/acme/order.go b/acme/order.go index 1719d899..7b0b2d4d 100644 --- a/acme/order.go +++ b/acme/order.go @@ -25,7 +25,7 @@ type Order struct { Identifiers []Identifier `json:"identifiers"` NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` - Error interface{} `json:"error,omitempty"` + Error *Error `json:"error,omitempty"` AuthorizationIDs []string `json:"-"` AuthorizationURLs []string `json:"authorizations"` FinalizeURL string `json:"finalize"` diff --git a/ca/ca.go b/ca/ca.go index 5ebc0919..43cbf0ba 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -124,24 +124,24 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } prefix := "acme" - acmeAuth, err := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ + acmeDB, err := acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) + if err != nil { + return nil, errors.Wrap(err, "error configuring ACME DB interface") + } + acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ Backdate: *config.AuthorityConfig.Backdate, - DB: acmeNoSQL.New(auth.GetDatabase().(nosql.DB)), + DB: acmeDB, DNS: dns, Prefix: prefix, CA: auth, }) - if err != nil { - return nil, errors.Wrap(err, "error creating ACME authority") - } - acmeRouterHandler := acmeAPI.New(acmeAuth) mux.Route("/"+prefix, func(r chi.Router) { - acmeRouterHandler.Route(r) + acmeHandler.Route(r) }) // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 // of the ACME spec. mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeRouterHandler.Route(r) + acmeHandler.Route(r) }) /* From f20fcae80e990baa03d1b4169b52b96a36cadde9 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 8 Mar 2021 22:35:57 -0800 Subject: [PATCH 14/47] [acme db interface] wip unit test fixing --- acme/api/account.go | 4 +- acme/api/account_test.go | 451 +++++++++++++------------------ acme/api/handler.go | 8 +- acme/api/handler_test.go | 511 ++++++++++-------------------------- acme/api/linker_test.go | 99 +++++++ acme/api/middleware_test.go | 477 +++++++++++++++++---------------- acme/db.go | 215 ++++++++++++++- acme/directory_test.go | 99 ------- 8 files changed, 887 insertions(+), 977 deletions(-) create mode 100644 acme/api/linker_test.go delete mode 100644 acme/directory_test.go diff --git a/acme/api/account.go b/acme/api/account.go index c06c034a..30d406e4 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -180,8 +180,8 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } } -// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { +// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. +func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/account_test.go b/acme/api/account_test.go index bdd61c59..d94819c7 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" @@ -53,7 +52,7 @@ func TestNewAccountRequestValidate(t *testing.T) { OnlyReturnExisting: true, Contact: []string{"foo", "bar"}, }, - err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")), + err: acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone"), } }, "fail/bad-contact": func(t *testing.T) test { @@ -61,7 +60,7 @@ func TestNewAccountRequestValidate(t *testing.T) { nar: &NewAccountRequest{ Contact: []string{"foo", ""}, }, - err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "ok": func(t *testing.T) test { @@ -109,8 +108,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) { Contact: []string{"foo", "bar"}, Status: "foo", }, - err: acme.MalformedErr(errors.Errorf("incompatible input; " + - "contact and status updates are mutually exclusive")), + err: acme.NewError(acme.ErrorMalformedType, "incompatible input; "+ + "contact and status updates are mutually exclusive"), } }, "fail/bad-contact": func(t *testing.T) test { @@ -118,7 +117,7 @@ func TestUpdateAccountRequestValidate(t *testing.T) { uar: &UpdateAccountRequest{ Contact: []string{"foo", ""}, }, - err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/bad-status": func(t *testing.T) test { @@ -126,8 +125,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) { uar: &UpdateAccountRequest{ Status: "foo", }, - err: acme.MalformedErr(errors.Errorf("cannot update account " + - "status to foo, only deactivated")), + err: acme.NewError(acme.ErrorMalformedType, "cannot update account "+ + "status to foo, only deactivated"), } }, "ok/contact": func(t *testing.T) test { @@ -168,13 +167,12 @@ func TestUpdateAccountRequestValidate(t *testing.T) { } } -func TestHandlerGetOrdersByAccount(t *testing.T) { +func TestHandler_GetOrdersByAccountID(t *testing.T) { oids := []string{ "https://ca.smallstep.com/acme/order/foo", "https://ca.smallstep.com/acme/order/bar", } accID := "account-id" - prov := newProv() // Request with chi context chiCtx := chi.NewRouteContext() @@ -182,67 +180,59 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "foo"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"), } }, "fail/getOrdersByAccount-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) + db: &acme.MockDB{ + MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { assert.Equals(t, id, acc.ID) return oids, nil }, @@ -255,11 +245,11 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetOrdersByAccount(w, req) + h.GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -268,15 +258,14 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(oids) @@ -288,7 +277,7 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { } } -func TestHandlerNewAccount(t *testing.T) { +func TestHandler_NewAccount(t *testing.T) { accID := "accountID" acc := acme.Account{ ID: accID, @@ -300,35 +289,34 @@ func TestHandlerNewAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to "+ + "unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -337,12 +325,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/no-existing-account": func(t *testing.T) test { @@ -351,12 +338,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-jwk": func(t *testing.T) test { @@ -365,12 +351,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { @@ -379,13 +364,12 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, nil) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/NewAccount-error": func(t *testing.T) test { @@ -396,23 +380,19 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - return test{ - auth: &mockAcmeAuthority{ - newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.Contact, nar.Contact) - assert.Equals(t, ops.Key, jwk) - return nil, acme.ServerInternalErr(errors.New("force")) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + return test{ + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok/new-account": func(t *testing.T) test { @@ -423,28 +403,27 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - return test{ - auth: &mockAcmeAuthority{ - newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.Contact, nar.Contact) - assert.Equals(t, ops.Key, jwk) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return nil }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.True(t, abs) + assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + */ }, ctx: ctx, statusCode: 201, @@ -456,21 +435,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - }, + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ ctx: ctx, statusCode: 200, } @@ -479,7 +448,7 @@ func TestHandlerNewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -492,15 +461,14 @@ func TestHandlerNewAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(acc) @@ -527,55 +495,51 @@ func TestHandlerGetUpdateAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -584,62 +548,33 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, - "fail/Deactivate-error": func(t *testing.T) test { + "fail/update-error": func(t *testing.T) test { uar := &UpdateAccountRequest{ Status: "deactivated", } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - return test{ - auth: &mockAcmeAuthority{ - deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - return nil, acme.ServerInternalErr(errors.New("force")) - }, - }, - ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), - } - }, - "fail/UpdateAccount-error": func(t *testing.T) test { - uar := &UpdateAccountRequest{ - Contact: []string{"foo", "bar"}, - } - b, err := json.Marshal(uar) - assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - return test{ - auth: &mockAcmeAuthority{ - updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - assert.Equals(t, contacts, uar.Contact) - return nil, acme.ServerInternalErr(errors.New("force")) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Status, acme.StatusDeactivated) + assert.Equals(t, upd.ID, acc.ID) + return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok/deactivate": func(t *testing.T) test { @@ -648,27 +583,27 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - return test{ - auth: &mockAcmeAuthority{ - deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Status, acme.StatusDeactivated) + assert.Equals(t, upd.ID, acc.ID) + return nil }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + */ }, ctx: ctx, statusCode: 200, @@ -678,21 +613,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - }, + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ ctx: ctx, statusCode: 200, } @@ -703,49 +628,50 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - return test{ - auth: &mockAcmeAuthority{ - updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - assert.Equals(t, contacts, uar.Contact) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Contact, uar.Contact) + assert.Equals(t, upd.ID, acc.ID) + return nil }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + */ }, ctx: ctx, statusCode: 200, } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL, provName, accID) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + /* + auth: &mockAcmeAuthority{ + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL, provName, accID) + }, }, - }, + */ ctx: ctx, statusCode: 200, } @@ -754,7 +680,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -767,15 +693,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(acc) diff --git a/acme/api/handler.go b/acme/api/handler.go index 997456a7..31466c6c 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -90,9 +90,9 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) + r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) + r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } @@ -149,8 +149,8 @@ func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } -// GetAuthz ACME api for retrieving an Authz. -func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { +// GetAuthorization ACME api for retrieving an Authz. +func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 7e19ea75..8a5ac694 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,7 +3,6 @@ package api import ( "bytes" "context" - "crypto/x509" "encoding/json" "encoding/pem" "fmt" @@ -14,209 +13,13 @@ import ( "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) -type mockAcmeAuthority struct { - getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string - getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string - - deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error) - getAccount func(ctx context.Context, accID string) (*acme.Account, error) - getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error) - newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error) - updateAccount func(context.Context, string, []string) (*acme.Account, error) - - getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error) - validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error) - getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error) - getDirectory func(ctx context.Context) (*acme.Directory, error) - getCertificate func(string, string) ([]byte, error) - - finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error) - getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error) - getOrdersByAccount func(ctx context.Context, accID string) ([]string, error) - newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error) - - loadProvisionerByID func(string) (provisioner.Interface, error) - newNonce func() (string, error) - useNonce func(string) error - ret1 interface{} - err error -} - -func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) { - if m.deactivateAccount != nil { - return m.deactivateAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { - if m.finalizeOrder != nil { - return m.finalizeOrder(ctx, accID, id, csr) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) { - if m.getAccount != nil { - return m.getAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - if m.getAccountByKey != nil { - return m.getAccountByKey(ctx, jwk) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) { - if m.getAuthz != nil { - return m.getAuthz(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Authz), m.err -} - -func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) { - if m.getCertificate != nil { - return m.getCertificate(accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.([]byte), m.err -} - -func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) { - if m.getChallenge != nil { - return m.getChallenge(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Challenge), m.err -} - -func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) { - if m.getDirectory != nil { - return m.getDirectory(ctx) - } - return m.ret1.(*acme.Directory), m.err -} - -func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - if m.getLink != nil { - return m.getLink(ctx, typ, abs, ins...) - } - return m.ret1.(string) -} - -func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string { - if m.getLinkExplicit != nil { - return m.getLinkExplicit(typ, provID, abs, baseURL, ins...) - } - return m.ret1.(string) -} - -func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) { - if m.getOrder != nil { - return m.getOrder(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - if m.getOrdersByAccount != nil { - return m.getOrdersByAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.([]string), m.err -} - -func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { - if m.loadProvisionerByID != nil { - return m.loadProvisionerByID(provID) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - if m.newAccount != nil { - return m.newAccount(ctx, ops) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) NewNonce() (string, error) { - if m.newNonce != nil { - return m.newNonce() - } else if m.err != nil { - return "", m.err - } - return m.ret1.(string), m.err -} - -func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - if m.newOrder != nil { - return m.newOrder(ctx, ops) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) { - if m.updateAccount != nil { - return m.updateAccount(ctx, id, contact) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) UseNonce(nonce string) error { - if m.useNonce != nil { - return m.useNonce(nonce) - } - return m.err -} - -func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { - switch { - case m.validateChallenge != nil: - return m.validateChallenge(ctx, accID, id, jwk) - case m.err != nil: - return nil, m.err - default: - return m.ret1.(*acme.Challenge), m.err - } -} - -func TestHandlerGetNonce(t *testing.T) { +func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string statusCode int @@ -230,7 +33,7 @@ func TestHandlerGetNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name h.GetNonce(w, req) @@ -243,21 +46,16 @@ func TestHandlerGetNonce(t *testing.T) { } } -func TestHandlerGetDirectory(t *testing.T) { - auth, err := acme.New(nil, acme.AuthorityOptions{ - DB: new(db.MockNoSQLDB), - DNS: "ca.smallstep.com", - Prefix: "acme", - }) - assert.FatalError(t, err) +func TestHandler_GetDirectory(t *testing.T) { + linker := NewLinker("acme", "ca.smallstep.com") prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - expDir := acme.Directory{ + expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), @@ -267,7 +65,7 @@ func TestHandlerGetDirectory(t *testing.T) { type test struct { statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { @@ -279,7 +77,7 @@ func TestHandlerGetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(auth).(*Handler) + h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -292,18 +90,17 @@ func TestHandlerGetDirectory(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - var dir acme.Directory + var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) assert.Equals(t, dir, expDir) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -312,16 +109,16 @@ func TestHandlerGetDirectory(t *testing.T) { } } -func TestHandlerGetAuthz(t *testing.T) { +func TestHandler_GetAuthz(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) - az := acme.Authz{ + az := acme.Authorization{ ID: "authzID", Identifier: acme.Identifier{ Type: "dns", Value: "example.com", }, Status: "pending", - Expires: expiry.Format(time.RFC3339), + Expires: expiry, Wildcard: false, Challenges: []*acme.Challenge{ { @@ -353,67 +150,64 @@ func TestHandlerGetAuthz(t *testing.T) { baseURL.String(), provName, az.ID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/getAuthz-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { assert.Equals(t, id, az.ID) return &az, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.True(t, abs) - assert.Equals(t, in, []string{az.ID}) - return url - }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AuthzLink) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.True(t, abs) + assert.Equals(t, in, []string{az.ID}) + return url + }, + */ }, ctx: ctx, statusCode: 200, @@ -423,11 +217,11 @@ func TestHandlerGetAuthz(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAuthz(w, req) + h.GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -436,15 +230,14 @@ func TestHandlerGetAuthz(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { //var gotAz acme.Authz @@ -459,7 +252,7 @@ func TestHandlerGetAuthz(t *testing.T) { } } -func TestHandlerGetCertificate(t *testing.T) { +func TestHandler_GetCertificate(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt") @@ -490,89 +283,83 @@ func TestHandlerGetCertificate(t *testing.T) { baseURL.String(), provName, certID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/getCertificate-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/decode-leaf-for-loggger": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return []byte("foo"), nil + return &acme.Certificate{}, nil }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")), + err: acme.NewErrorISE("failed to decode any certificates from generated certBytes"), } }, "fail/parse-x509-leaf-for-logger": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: []byte("foo"), - }), nil + return &acme.Certificate{}, nil }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to parse generated leaf certificate")), + err: acme.NewErrorISE("failed to parse generated leaf certificate"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return certBytes, nil + return &acme.Certificate{}, nil }, }, ctx: ctx, @@ -583,7 +370,7 @@ func TestHandlerGetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -596,15 +383,14 @@ func TestHandlerGetCertificate(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.HasPrefix(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.HasPrefix(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes)) @@ -634,121 +420,115 @@ func TestHandlerGetChallenge(t *testing.T) { url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID") type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int ch acme.Challenge - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ - ctx: ctx, + ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.UnauthorizedErr(nil), + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(nil), + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), } }, "fail/get-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.UnauthorizedErr(nil), + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(nil), + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), } }, "ok/validate-challenge": func(t *testing.T) test { key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ch := ch() ch.Status = "valid" ch.Validated = time.Now().UTC().Format(time.RFC3339) count := 0 return test{ - auth: &mockAcmeAuthority{ - validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, ch.ID) - assert.Equals(t, jwk.KeyID, key.KeyID) + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, ch.ID) return &ch, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - var ret string - switch count { - case 0: - assert.Equals(t, typ, acme.AuthzLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) - case 1: - assert.Equals(t, typ, acme.ChallengeLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.ID}) - ret = url - } - count++ - return ret - }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + var ret string + switch count { + case 0: + assert.Equals(t, typ, acme.AuthzLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.AuthzID}) + ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) + case 1: + assert.Equals(t, typ, acme.ChallengeLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.ID}) + ret = url + } + count++ + return ret + }, + */ }, ctx: ctx, statusCode: 200, @@ -759,7 +539,7 @@ func TestHandlerGetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -772,15 +552,14 @@ func TestHandlerGetChallenge(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(tc.ch) diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go new file mode 100644 index 00000000..ab1ad3ba --- /dev/null +++ b/acme/api/linker_test.go @@ -0,0 +1,99 @@ +package api + +import ( + "context" + "fmt" + "net/url" + "testing" + + "github.com/smallstep/assert" +) + +func TestLinkerGetLink(t *testing.T) { + dns := "ca.smallstep.com" + prefix := "acme" + linker := NewLinker(dns, prefix) + id := "1234" + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, true), + fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) + + // No provisioner + ctxNoProv := context.WithValue(context.Background(), baseURLContextKey, baseURL) + assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, true), + fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) + assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, false), "//new-nonce") + + // No baseURL + ctxNoBaseURL := context.WithValue(context.Background(), provisionerContextKey, prov) + assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, true), + fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) + assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) + + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, true, id), + fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName)) +} + +func TestLinkerGetLinkExplicit(t *testing.T) { + dns := "ca.smallstep.com" + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prefix := "acme" + linker := NewLinker(dns, prefix) + id := "1234" + + prov := newProv() + provID := url.PathEscape(prov.GetName()) + + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) + + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) + + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) + + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) + + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) + + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) + + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) + + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) + + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) + + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) + + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) + + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) + + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) + + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) +} diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index d2a9cdc0..750b019d 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -82,13 +82,13 @@ func Test_baseURLFromRequest(t *testing.T) { } func TestHandlerBaseURLFromRequest(t *testing.T) { - h := New(&mockAcmeAuthority{}).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req.Host = "test.ca.smallstep.com:8080" w := httptest.NewRecorder() next := func(w http.ResponseWriter, r *http.Request) { - bu := acme.BaseURLFromContext(r.Context()) + bu := baseURLFromContext(r.Context()) if assert.NotNil(t, bu) { assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") assert.Equals(t, bu.Scheme, "https") @@ -101,35 +101,35 @@ func TestHandlerBaseURLFromRequest(t *testing.T) { req.Host = "" next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, acme.BaseURLFromContext(r.Context()), nil) + assert.Equals(t, baseURLFromContext(r.Context()), nil) } h.baseURLFromRequest(next)(w, req) } -func TestHandlerAddNonce(t *testing.T) { +func TestHandler_AddNonce(t *testing.T) { url := "https://ca.smallstep.com/acme/new-nonce" type test struct { - auth acme.Interface - problem *acme.Error + db acme.DB + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/AddNonce-error": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{ - newNonce: func() (string, error) { - return "", acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { + return acme.Nonce(""), acme.NewErrorISE("force") }, }, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{ - newNonce: func() (string, error) { + db: &acme.MockDB{ + MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { return "bar", nil }, }, @@ -140,7 +140,7 @@ func TestHandlerAddNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) w := httptest.NewRecorder() h.addNonce(testNext)(w, req) @@ -152,15 +152,14 @@ func TestHandlerAddNonce(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Replay-Nonce"], []string{"bar"}) @@ -176,22 +175,24 @@ func TestHandlerAddDirLink(t *testing.T) { provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB link string statusCode int ctx context.Context - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) - }, + db: &acme.MockDB{ + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, baseURLFromContext(ctx), baseURL) + return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) + }, + */ }, ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), @@ -202,7 +203,7 @@ func TestHandlerAddDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -215,15 +216,14 @@ func TestHandlerAddDirLink(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s>;rel=\"index\"", tc.link)}) @@ -242,7 +242,7 @@ func TestHandlerVerifyContentType(t *testing.T) { h Handler ctx context.Context contentType string - problem *acme.Error + err *acme.Error statusCode int url string } @@ -250,7 +250,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -260,16 +260,16 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, - problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")), + err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), } }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -278,16 +278,16 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, - problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")), + err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), } }, "ok": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -296,7 +296,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -304,7 +304,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -313,7 +313,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkix-cert", statusCode: 200, } @@ -321,7 +321,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/jose+json": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -330,7 +330,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -338,7 +338,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -347,7 +347,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -373,15 +373,14 @@ func TestHandlerVerifyContentType(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -394,7 +393,7 @@ func TestHandlerIsPostAsGet(t *testing.T) { url := "https://ca.smallstep.com/acme/new-account" type test struct { ctx context.Context - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -402,26 +401,26 @@ func TestHandlerIsPostAsGet(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, nil), + ctx: context.WithValue(context.Background(), payloadContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/not-post-as-get": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{}), + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), statusCode: 400, - problem: acme.MalformedErr(errors.New("expected POST-as-GET")), + err: acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"), } }, "ok": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}), + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), statusCode: 200, } }, @@ -429,7 +428,7 @@ func TestHandlerIsPostAsGet(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -442,15 +441,14 @@ func TestHandlerIsPostAsGet(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -473,7 +471,7 @@ func TestHandlerParseJWS(t *testing.T) { type test struct { next nextHTTP body io.Reader - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -481,14 +479,14 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: errReader(0), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to read request body: force")), + err: acme.NewErrorISE("failed to read request body: force"), } }, "fail/parse-jws-error": func(t *testing.T) test { return test{ body: strings.NewReader("foo"), statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts")), + err: acme.NewError(acme.ErrorMalformedType, "failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts"), } }, "ok": func(t *testing.T) test { @@ -507,7 +505,7 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: strings.NewReader(expRaw), next: func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + jws, err := jwsFromContext(r.Context()) assert.FatalError(t, err) gotRaw, err := jws.CompactSerialize() assert.FatalError(t, err) @@ -521,7 +519,7 @@ func TestHandlerParseJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, tc.body) w := httptest.NewRecorder() h.parseJWS(tc.next)(w, req) @@ -533,15 +531,14 @@ func TestHandlerParseJWS(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -572,7 +569,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { type test struct { ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -580,58 +577,58 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), + ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-jwk": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) return test{ - ctx: context.WithValue(ctx, acme.JwkContextKey, nil), + ctx: context.WithValue(ctx, jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/verify-jws-failure": func(t *testing.T) test { _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, &_pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("error verifying jws: square/go-jose: error in cryptographic primitive")), + err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: square/go-jose: error in cryptographic primitive"), } }, "fail/algorithm-mismatch": func(t *testing.T) test { _pub := *pub clone := &_pub clone.Algorithm = jose.HS256 - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, clone) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, clone) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("verifier and signature algorithm do not match")), + err: acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"), } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -651,8 +648,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _pub := *pub clone := &_pub clone.Algorithm = "" - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -675,8 +672,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -699,8 +696,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -720,7 +717,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -733,15 +730,14 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -775,27 +771,27 @@ func TestHandlerLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-kid": func(t *testing.T) test { @@ -806,11 +802,11 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) assert.True(t, abs) @@ -820,7 +816,7 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got ", prefix)), + err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), } }, "fail/bad-kid-prefix": func(t *testing.T) test { @@ -837,11 +833,11 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _parsed) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) assert.True(t, abs) @@ -851,15 +847,15 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got foo", prefix)), + err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -876,21 +872,21 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.AccountDoesNotExistErr(nil), } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, accID) - return nil, acme.ServerInternalErr(errors.New("force")) + return nil, acme.NewErrorISE("force") }, getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) @@ -901,16 +897,16 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -927,16 +923,16 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.UnauthorizedErr(errors.New("account is not active")), } }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -981,15 +977,14 @@ func TestHandlerLookupJWK(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.AError assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -1024,27 +1019,27 @@ func TestHandlerExtractJWK(t *testing.T) { url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provName) type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { @@ -1057,8 +1052,8 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, @@ -1075,39 +1070,39 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("invalid jwk in protected header")), + err: acme.MalformedErr(errors.New("invalid jwk in protected header")), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, jwk.KeyID, pub.KeyID) - return nil, acme.ServerInternalErr(errors.New("force")) + return nil, acme.NewErrorISE("force") }, }, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -1117,16 +1112,16 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.UnauthorizedErr(errors.New("account is not active")), } }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -1148,11 +1143,11 @@ func TestHandlerExtractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -1190,15 +1185,14 @@ func TestHandlerExtractJWK(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.AError assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -1210,10 +1204,10 @@ func TestHandlerExtractJWK(t *testing.T) { func TestHandlerValidateJWS(t *testing.T) { url := "https://ca.smallstep.com/acme/account/1234" type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1221,21 +1215,21 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), + ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-signature": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, &jose.JSONWebSignature{}), + ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, - problem: acme.MalformedErr(errors.New("request body does not contain a signature")), + err: acme.MalformedErr(errors.New("request body does not contain a signature")), } }, "fail/more-than-one-signature": func(t *testing.T) test { @@ -1246,9 +1240,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("request body contains more than one signature")), + err: acme.MalformedErr(errors.New("request body contains more than one signature")), } }, "fail/unprotected-header-not-empty": func(t *testing.T) test { @@ -1258,9 +1252,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("unprotected header must not be used")), + err: acme.MalformedErr(errors.New("unprotected header must not be used")), } }, "fail/unsuitable-algorithm-none": func(t *testing.T) test { @@ -1270,9 +1264,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")), + err: acme.MalformedErr(errors.New("unsuitable algorithm: none")), } }, "fail/unsuitable-algorithm-mac": func(t *testing.T) test { @@ -1282,9 +1276,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), + err: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), } }, "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { @@ -1305,14 +1299,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), + err: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), } }, "fail/rsa-key-too-small": func(t *testing.T) test { @@ -1333,14 +1327,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), + err: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), } }, "fail/UseNonce-error": func(t *testing.T) test { @@ -1350,14 +1344,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { - return acme.ServerInternalErr(errors.New("force")) + return acme.NewErrorISE("force") }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/no-url-header": func(t *testing.T) test { @@ -1367,14 +1361,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("jws missing url protected header")), + err: acme.MalformedErr(errors.New("jws missing url protected header")), } }, "fail/url-mismatch": func(t *testing.T) test { @@ -1391,14 +1385,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), + err: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), } }, "fail/both-jwk-kid": func(t *testing.T) test { @@ -1420,14 +1414,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), + err: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), } }, "fail/no-jwk-kid": func(t *testing.T) test { @@ -1444,14 +1438,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), + err: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), } }, "ok/kid": func(t *testing.T) test { @@ -1469,12 +1463,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1499,12 +1493,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1529,12 +1523,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1558,15 +1552,14 @@ func TestHandlerValidateJWS(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.AError assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) diff --git a/acme/db.go b/acme/db.go index a19621c0..dcc7846f 100644 --- a/acme/db.go +++ b/acme/db.go @@ -1,6 +1,8 @@ package acme -import "context" +import ( + "context" +) // DB is the DB interface expected by the step-ca ACME API. type DB interface { @@ -28,3 +30,214 @@ type DB interface { GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) UpdateOrder(ctx context.Context, o *Order) error } + +// MockDB is an implementation of the DB interface that should only be used as +// a mock in tests. +type MockDB struct { + MockCreateAccount func(ctx context.Context, acc *Account) error + MockGetAccount func(ctx context.Context, id string) (*Account, error) + MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error) + MockUpdateAccount func(ctx context.Context, acc *Account) error + + MockCreateNonce func(ctx context.Context) (Nonce, error) + MockDeleteNonce func(ctx context.Context, nonce Nonce) error + + MockCreateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) + MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + + MockCreateCertificate func(ctx context.Context, cert *Certificate) error + MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) + + MockCreateChallenge func(ctx context.Context, ch *Challenge) error + MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error) + MockUpdateChallenge func(ctx context.Context, ch *Challenge) error + + MockCreateOrder func(ctx context.Context, o *Order) error + MockGetOrder func(ctx context.Context, id string) (*Order, error) + MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) + MockUpdateOrder func(ctx context.Context, o *Order) error + + MockRet1 interface{} + MockError error +} + +// CreateAccount mock. +func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error { + if m.MockCreateAccount != nil { + return m.MockCreateAccount(ctx, acc) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetAccount mock. +func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) { + if m.MockGetAccount != nil { + return m.MockGetAccount(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Account), m.MockError +} + +// GetAccountByKeyID mock +func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { + if m.MockGetAccountByKeyID != nil { + return m.MockGetAccountByKeyID(ctx, kid) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Account), m.MockError +} + +// UpdateAccount mock +func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error { + if m.MockUpdateAccount != nil { + return m.MockUpdateAccount(ctx, acc) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateNonce mock +func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) { + if m.MockCreateNonce != nil { + return m.MockCreateNonce(ctx) + } else if m.MockError != nil { + return Nonce(""), m.MockError + } + return m.MockRet1.(Nonce), m.MockError +} + +// DeleteNonce mock +func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error { + if m.MockDeleteNonce != nil { + return m.MockDeleteNonce(ctx, nonce) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateAuthorization mock +func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error { + if m.MockCreateAuthorization != nil { + return m.MockCreateAuthorization(ctx, az) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetAuthorization mock +func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) { + if m.MockGetAuthorization != nil { + return m.MockGetAuthorization(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Authorization), m.MockError +} + +// UpdateAuthorization mock +func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error { + if m.MockUpdateAuthorization != nil { + return m.MockUpdateAuthorization(ctx, az) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateCertificate mock +func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { + if m.MockCreateCertificate != nil { + return m.MockCreateCertificate(ctx, cert) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetCertificate mock +func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { + if m.MockGetCertificate != nil { + return m.MockGetCertificate(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Certificate), m.MockError +} + +// CreateChallenge mock +func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error { + if m.MockCreateChallenge != nil { + return m.MockCreateChallenge(ctx, ch) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetChallenge mock +func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) { + if m.MockGetChallenge != nil { + return m.MockGetChallenge(ctx, chID, azID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Challenge), m.MockError +} + +// UpdateChallenge mock +func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error { + if m.MockUpdateChallenge != nil { + return m.MockUpdateChallenge(ctx, ch) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateOrder mock +func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error { + if m.MockCreateOrder != nil { + return m.MockCreateOrder(ctx, o) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetOrder mock +func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) { + if m.MockGetOrder != nil { + return m.MockGetOrder(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Order), m.MockError +} + +// UpdateOrder mock +func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error { + if m.MockUpdateOrder != nil { + return m.MockUpdateOrder(ctx, o) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetOrdersByAccountID mock +func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { + if m.MockGetOrdersByAccountID != nil { + return m.MockGetOrdersByAccountID(ctx, accID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.([]string), m.MockError +} diff --git a/acme/directory_test.go b/acme/directory_test.go deleted file mode 100644 index dd4c534c..00000000 --- a/acme/directory_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package acme - -import ( - "context" - "fmt" - "net/url" - "testing" - - "github.com/smallstep/assert" -) - -func TestDirectoryGetLink(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - dir := newDirectory(dns, prefix) - id := "1234" - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - assert.Equals(t, dir.getLink(ctx, NewNonceLink, true), - fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) - assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - - // No provisioner - ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL) - assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true), - fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) - assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce") - - // No baseURL - ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov) - assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true), - fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) - assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - - assert.Equals(t, dir.getLink(ctx, OrderLink, true, id), - fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) - assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName)) -} - -func TestDirectoryGetLinkExplicit(t *testing.T) { - dns := "ca.smallstep.com" - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prefix := "acme" - dir := newDirectory(dns, prefix) - id := "1234" - - prov := newProv() - provID := url.PathEscape(prov.GetName()) - - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) - - assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) - - assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) - - assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) - - assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) - - assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) - - assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) - - assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) -} From bb8d54e596b4ebed16953c19944b3fb67e9515df Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 10 Mar 2021 10:50:51 -0800 Subject: [PATCH 15/47] [acme db interface] unit tests compiling --- acme/account_test.go | 22 +-- acme/api/handler.go | 2 +- acme/api/handler_test.go | 1 - acme/api/linker.go | 120 ++++++++++-- acme/api/middleware_test.go | 248 +++++++----------------- acme/api/order_test.go | 373 ++++++++++++++++-------------------- acme/authority_test.go | 22 +-- acme/authz_test.go | 18 +- acme/certificate_test.go | 19 +- acme/challenge_test.go | 43 +---- acme/nonce_test.go | 13 +- acme/order_test.go | 21 +- ca/acmeClient_test.go | 90 ++++----- 13 files changed, 413 insertions(+), 579 deletions(-) diff --git a/acme/account_test.go b/acme/account_test.go index 2e072af5..45b86f20 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -1,20 +1,10 @@ package acme import ( - "context" - "encoding/json" "fmt" - "net/url" - "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" - "go.step.sm/crypto/jose" ) var ( @@ -39,7 +29,8 @@ func newProv() Provisioner { return p } -func newAcc() (*account, error) { +/* +func newAcc() (*Account, error) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) if err != nil { return nil, err @@ -53,12 +44,14 @@ func newAcc() (*account, error) { Key: jwk, Contact: []string{"foo", "bar"}, }) } +*/ +/* func TestGetAccountByID(t *testing.T) { type test struct { id string db nosql.DB - acc *account + acc *Account err *Error } tests := map[string]func(t *testing.T) test{ @@ -73,7 +66,7 @@ func TestGetAccountByID(t *testing.T) { return nil, database.ErrNotFound }, }, - err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)), + err: NewError(ErrorMalformedType, "account %s not found: not found", acc.ID), } }, "fail/db-error": func(t *testing.T) test { @@ -87,7 +80,7 @@ func TestGetAccountByID(t *testing.T) { return nil, errors.New("force") }, }, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)), + err: NewErrorISE("error loading account %s: force", acc.ID), } }, "fail/unmarshal-error": func(t *testing.T) test { @@ -768,3 +761,4 @@ func TestNewAccount(t *testing.T) { }) } } +*/ diff --git a/acme/api/handler.go b/acme/api/handler.go index 31466c6c..a6d35bb3 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -39,7 +39,7 @@ type Handler struct { db acme.DB backdate provisioner.Duration ca acme.CertificateAuthority - linker *Linker + linker Linker } // HandlerOptions required to create a new ACME API request handler. diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 8a5ac694..23db72a5 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -503,7 +503,6 @@ func TestHandlerGetChallenge(t *testing.T) { ch := ch() ch.Status = "valid" ch.Validated = time.Now().UTC().Format(time.RFC3339) - count := 0 return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { diff --git a/acme/api/linker.go b/acme/api/linker.go index b9215e06..6688732d 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -9,18 +9,30 @@ import ( ) // NewLinker returns a new Directory type. -func NewLinker(dns, prefix string) *Linker { - return &Linker{Prefix: prefix, DNS: dns} +func NewLinker(dns, prefix string) Linker { + return &linker{prefix: prefix, dns: dns} } -// Linker generates ACME links. -type Linker struct { - Prefix string - DNS string +// Linker interface for generating links for ACME resources. +type Linker interface { + GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string + GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string + + LinkOrder(ctx context.Context, o *acme.Order) + LinkAccount(ctx context.Context, o *acme.Account) + LinkChallenge(ctx context.Context, o *acme.Challenge) + LinkAuthorization(ctx context.Context, o *acme.Authorization) + LinkOrdersByAccountID(ctx context.Context, orders []string) +} + +// linker generates ACME links. +type linker struct { + prefix string + dns string } // GetLink is a helper for GetLinkExplicit -func (l *Linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { +func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { var provName string if p, err := provisionerFromContext(ctx); err == nil && p != nil { provName = p.GetName() @@ -31,7 +43,7 @@ func (l *Linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ... // GetLinkExplicit returns an absolute or partial path to the given resource and a base // URL dynamically obtained from the request for which the link is being // calculated. -func (l *Linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { +func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { var link string switch typ { case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: @@ -60,10 +72,10 @@ func (l *Linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, // If no Host is set, then use the default (first DNS attr in the ca.json). if u.Host == "" { - u.Host = l.DNS + u.Host = l.dns } - u.Path = l.Prefix + link + u.Path = l.prefix + link return u.String() } return link @@ -135,7 +147,7 @@ func (l LinkType) String() string { } // LinkOrder sets the ACME links required by an ACME order. -func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) { +func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) @@ -147,25 +159,103 @@ func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) { } // LinkAccount sets the ACME links required by an ACME account. -func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) { +func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { acc.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. -func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { +func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { ch.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. -func (l *Linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { +func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { for _, ch := range az.Challenges { l.LinkChallenge(ctx, ch) } } // LinkOrdersByAccountID converts each order ID to an ACME link. -func (l *Linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { +func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { for i, id := range orders { orders[i] = l.GetLink(ctx, OrderLinkType, true, id) } } + +// MockLinker implements the Linker interface. Only used for testing. +type MockLinker struct { + MockGetLink func(ctx context.Context, typ LinkType, abs bool, inputs ...string) string + MockGetLinkExplicit func(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string + + MockLinkOrder func(ctx context.Context, o *acme.Order) + MockLinkAccount func(ctx context.Context, o *acme.Account) + MockLinkChallenge func(ctx context.Context, o *acme.Challenge) + MockLinkAuthorization func(ctx context.Context, o *acme.Authorization) + MockLinkOrdersByAccountID func(ctx context.Context, orders []string) + + MockError error + MockRet1 interface{} +} + +// GetLink mock. +func (m *MockLinker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { + if m.MockGetLink != nil { + return m.MockGetLink(ctx, typ, abs, inputs...) + } + + return m.MockRet1.(string) +} + +// GetLinkExplicit mock. +func (m *MockLinker) GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string { + if m.MockGetLinkExplicit != nil { + return m.MockGetLinkExplicit(typ, provName, abs, baseURL, inputs...) + } + + return m.MockRet1.(string) +} + +// LinkOrder mock. +func (m *MockLinker) LinkOrder(ctx context.Context, o *acme.Order) { + if m.MockLinkOrder != nil { + m.MockLinkOrder(ctx, o) + return + } + return +} + +// LinkAccount mock. +func (m *MockLinker) LinkAccount(ctx context.Context, o *acme.Account) { + if m.MockLinkAccount != nil { + m.MockLinkAccount(ctx, o) + return + } + return +} + +// LinkChallenge mock. +func (m *MockLinker) LinkChallenge(ctx context.Context, o *acme.Challenge) { + if m.MockLinkChallenge != nil { + m.MockLinkChallenge(ctx, o) + return + } + return +} + +// LinkAuthorization mock. +func (m *MockLinker) LinkAuthorization(ctx context.Context, o *acme.Authorization) { + if m.MockLinkAuthorization != nil { + m.MockLinkAuthorization(ctx, o) + return + } + return +} + +// LinkOrderAccountsByID mock. +func (m *MockLinker) LinkOrderAccountsByID(ctx context.Context, orders []string) { + if m.MockLinkOrdersByAccountID != nil { + m.MockLinkOrdersByAccountID(ctx, orders) + return + } + return +} diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 750b019d..c6c855a8 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -170,13 +170,13 @@ func TestHandler_AddNonce(t *testing.T) { } } -func TestHandlerAddDirLink(t *testing.T) { +func TestHandler_addDirLink(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - db acme.DB link string + linker Linker statusCode int ctx context.Context err *acme.Error @@ -186,14 +186,7 @@ func TestHandlerAddDirLink(t *testing.T) { ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - db: &acme.MockDB{ - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, baseURLFromContext(ctx), baseURL) - return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) - }, - */ - }, + linker: NewLinker("dns", "acme"), ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, @@ -203,7 +196,7 @@ func TestHandlerAddDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -233,7 +226,7 @@ func TestHandlerAddDirLink(t *testing.T) { } } -func TestHandlerVerifyContentType(t *testing.T) { +func TestHandler_VerifyContentType(t *testing.T) { prov := newProv() provName := prov.GetName() baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -250,14 +243,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("/acme/%s/certificate/", provName) - }, - }, + linker: NewLinker("dns", "acme"), }, url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), ctx: context.WithValue(context.Background(), provisionerContextKey, prov), @@ -269,14 +255,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", @@ -287,14 +266,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok": func(t *testing.T) test { return test{ h: Handler{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", @@ -304,14 +276,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ h: Handler{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkix-cert", @@ -321,14 +286,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/jose+json": func(t *testing.T) test { return test{ h: Handler{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", @@ -338,14 +296,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ h: Handler{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkcs7-mime", @@ -771,6 +722,7 @@ func TestHandlerLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { + linker Linker db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) @@ -806,14 +758,7 @@ func TestHandlerLookupJWK(t *testing.T) { ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return prefix - }, - }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), @@ -837,14 +782,7 @@ func TestHandlerLookupJWK(t *testing.T) { ctx = context.WithValue(ctx, jwsContextKey, _parsed) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - db: &acme.MockDB{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, - }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), @@ -855,24 +793,16 @@ func TestHandlerLookupJWK(t *testing.T) { ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + linker: NewLinker("dns", "acme"), db: &acme.MockDB{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) + MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) return nil, database.ErrNotFound }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, statusCode: 400, - err: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/GetAccount-error": func(t *testing.T) test { @@ -880,20 +810,12 @@ func TestHandlerLookupJWK(t *testing.T) { ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + linker: NewLinker("dns", "acme"), db: &acme.MockDB{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, accID) + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) return nil, acme.NewErrorISE("force") }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, statusCode: 500, @@ -906,24 +828,16 @@ func TestHandlerLookupJWK(t *testing.T) { ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + linker: NewLinker("dns", "acme"), db: &acme.MockDB{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, accID) + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) return acc, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, statusCode: 401, - err: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"), } }, "ok": func(t *testing.T) test { @@ -932,27 +846,19 @@ func TestHandlerLookupJWK(t *testing.T) { ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + linker: NewLinker("dns", "acme"), db: &acme.MockDB{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, accID) + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) return acc, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := acme.AccountFromContext(r.Context()) + _acc, err := accountFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _acc, acc) - _jwk, err := acme.JwkFromContext(r.Context()) + _jwk, err := jwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk, jwk) w.Write(testBody) @@ -964,7 +870,7 @@ func TestHandlerLookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: tc.linker} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -978,7 +884,7 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { - var ae acme.AError + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) @@ -1057,7 +963,7 @@ func TestHandlerExtractJWK(t *testing.T) { return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("jwk expected in protected header")), + err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), } }, "fail/invalid-jwk": func(t *testing.T) test { @@ -1075,7 +981,7 @@ func TestHandlerExtractJWK(t *testing.T) { return test{ ctx: ctx, statusCode: 400, - err: acme.MalformedErr(errors.New("invalid jwk in protected header")), + err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { @@ -1084,11 +990,8 @@ func TestHandlerExtractJWK(t *testing.T) { return test{ ctx: ctx, db: &acme.MockDB{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) return nil, acme.NewErrorISE("force") }, }, @@ -1103,16 +1006,13 @@ func TestHandlerExtractJWK(t *testing.T) { return test{ ctx: ctx, db: &acme.MockDB{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) return acc, nil }, }, statusCode: 401, - err: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"), } }, "ok": func(t *testing.T) test { @@ -1122,19 +1022,16 @@ func TestHandlerExtractJWK(t *testing.T) { return test{ ctx: ctx, db: &acme.MockDB{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) return acc, nil }, }, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := acme.AccountFromContext(r.Context()) + _acc, err := accountFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _acc, acc) - _jwk, err := acme.JwkFromContext(r.Context()) + _jwk, err := jwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk.KeyID, pub.KeyID) w.Write(testBody) @@ -1148,19 +1045,16 @@ func TestHandlerExtractJWK(t *testing.T) { return test{ ctx: ctx, db: &acme.MockDB{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) return nil, database.ErrNotFound }, }, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := acme.AccountFromContext(r.Context()) + _acc, err := accountFromContext(r.Context()) assert.NotNil(t, err) assert.Nil(t, _acc) - _jwk, err := acme.JwkFromContext(r.Context()) + _jwk, err := jwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk.KeyID, pub.KeyID) w.Write(testBody) @@ -1172,7 +1066,7 @@ func TestHandlerExtractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -1186,7 +1080,7 @@ func TestHandlerExtractJWK(t *testing.T) { assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { - var ae acme.AError + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) @@ -1229,7 +1123,7 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, - err: acme.MalformedErr(errors.New("request body does not contain a signature")), + err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), } }, "fail/more-than-one-signature": func(t *testing.T) test { @@ -1242,7 +1136,7 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.New("request body contains more than one signature")), + err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), } }, "fail/unprotected-header-not-empty": func(t *testing.T) test { @@ -1254,7 +1148,7 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.New("unprotected header must not be used")), + err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), } }, "fail/unsuitable-algorithm-none": func(t *testing.T) test { @@ -1266,7 +1160,7 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.New("unsuitable algorithm: none")), + err: acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: none"), } }, "fail/unsuitable-algorithm-mac": func(t *testing.T) test { @@ -1278,7 +1172,7 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), + err: acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", jose.HS256), } }, "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { @@ -1300,13 +1194,13 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), + err: acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match"), } }, "fail/rsa-key-too-small": func(t *testing.T) test { @@ -1328,13 +1222,13 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), + err: acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least 2048 bits (256 bytes) in size"), } }, "fail/UseNonce-error": func(t *testing.T) test { @@ -1345,7 +1239,7 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return acme.NewErrorISE("force") }, }, @@ -1362,13 +1256,13 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.New("jws missing url protected header")), + err: acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"), } }, "fail/url-mismatch": func(t *testing.T) test { @@ -1386,13 +1280,13 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), + err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url), } }, "fail/both-jwk-kid": func(t *testing.T) test { @@ -1415,13 +1309,13 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), + err: acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"), } }, "fail/no-jwk-kid": func(t *testing.T) test { @@ -1439,13 +1333,13 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), + err: acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"), } }, "ok/kid": func(t *testing.T) test { @@ -1464,7 +1358,7 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, @@ -1494,7 +1388,7 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, @@ -1524,7 +1418,7 @@ func TestHandlerValidateJWS(t *testing.T) { } return test{ db: &acme.MockDB{ - useNonce: func(n string) error { + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, @@ -1539,7 +1433,7 @@ func TestHandlerValidateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -1553,7 +1447,7 @@ func TestHandlerValidateJWS(t *testing.T) { assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { - var ae acme.AError + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) diff --git a/acme/api/order_test.go b/acme/api/order_test.go index a1c8fef7..610713b6 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -14,7 +14,6 @@ import ( "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "go.step.sm/crypto/pemutil" @@ -30,7 +29,7 @@ func TestNewOrderRequestValidate(t *testing.T) { "fail/no-identifiers": func(t *testing.T) test { return test{ nor: &NewOrderRequest{}, - err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")), + err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, "fail/bad-identifier": func(t *testing.T) test { @@ -41,7 +40,7 @@ func TestNewOrderRequestValidate(t *testing.T) { {Type: "foo", Value: "bar.com"}, }, }, - err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")), + err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"), } }, "ok": func(t *testing.T) test { @@ -105,7 +104,7 @@ func TestFinalizeRequestValidate(t *testing.T) { "fail/parse-csr-error": func(t *testing.T) test { return test{ fr: &FinalizeRequest{}, - err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")), + err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, "fail/invalid-csr-signature": func(t *testing.T) test { @@ -117,7 +116,7 @@ func TestFinalizeRequestValidate(t *testing.T) { fr: &FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(c.Raw), }, - err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")), + err: acme.NewError(acme.ErrorMalformedType, "csr failed signature check: x509: ECDSA verification failure"), } }, "ok": func(t *testing.T) test { @@ -148,15 +147,15 @@ func TestFinalizeRequestValidate(t *testing.T) { } } -func TestHandlerGetOrder(t *testing.T) { +func TestHandler_GetOrder(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) nbf := time.Now().UTC() naf := time.Now().UTC().Add(24 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry.Format(time.RFC3339), - NotBefore: nbf.Format(time.RFC3339), - NotAfter: naf.Format(time.RFC3339), + Expires: expiry, + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ { Type: "dns", @@ -167,8 +166,8 @@ func TestHandlerGetOrder(t *testing.T) { Value: "*.smallstep.com", }, }, - Status: "pending", - Authorizations: []string{"foo", "bar"}, + Status: "pending", + AuthorizationURLs: []string{"foo", "bar"}, } // Request with chi context @@ -181,67 +180,57 @@ func TestHandlerGetOrder(t *testing.T) { baseURL.String(), provName, o.ID) type test struct { - auth acme.Interface + db acme.DB + linker Linker ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/getOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getOrder: func(ctx context.Context, accID, id string) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { assert.Equals(t, id, o.ID) return &o, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return url - }, }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 200, } @@ -250,7 +239,7 @@ func TestHandlerGetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: tc.linker, db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -263,15 +252,14 @@ func TestHandlerGetOrder(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) @@ -290,15 +278,15 @@ func TestHandlerNewOrder(t *testing.T) { naf := nbf.Add(17 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry.Format(time.RFC3339), - NotBefore: nbf.Format(time.RFC3339), - NotAfter: naf.Format(time.RFC3339), + Expires: expiry, + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "dns", Value: "bar.com"}, }, - Status: "pending", - Authorizations: []string{"foo", "bar"}, + Status: "pending", + AuthorizationURLs: []string{"foo", "bar"}, } prov := newProv() @@ -308,58 +296,59 @@ func TestHandlerNewOrder(t *testing.T) { baseURL.String(), provName) type test struct { - auth acme.Interface + db acme.DB + linker Linker ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -367,13 +356,13 @@ func TestHandlerNewOrder(t *testing.T) { nor := &NewOrderRequest{} b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")), + err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, "fail/NewOrder-error": func(t *testing.T) test { @@ -386,23 +375,18 @@ func TestHandlerNewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ - auth: &mockAcmeAuthority{ - newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.AccountID, acc.ID) - assert.Equals(t, ops.Identifiers, nor.Identifiers) - return nil, acme.MalformedErr(errors.New("force")) + db: &acme.MockDB{ + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + return acme.NewError(acme.ErrorMalformedType, "force") }, }, ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("force")), + err: acme.NewError(acme.ErrorMalformedType, "force"), } }, "ok": func(t *testing.T) test { @@ -417,29 +401,17 @@ func TestHandlerNewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.AccountID, acc.ID) - assert.Equals(t, ops.Identifiers, nor.Identifiers) - assert.Equals(t, ops.NotBefore, nbf) - assert.Equals(t, ops.NotAfter, naf) - return &o, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) + db: &acme.MockDB{ + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + return nil }, }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 201, } @@ -454,30 +426,17 @@ func TestHandlerNewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.AccountID, acc.ID) - assert.Equals(t, ops.Identifiers, nor.Identifiers) - - assert.True(t, ops.NotBefore.IsZero()) - assert.True(t, ops.NotAfter.IsZero()) - return &o, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) + db: &acme.MockDB{ + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + return nil }, }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 201, } @@ -486,7 +445,7 @@ func TestHandlerNewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: tc.linker, db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -499,15 +458,14 @@ func TestHandlerNewOrder(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) @@ -522,22 +480,22 @@ func TestHandlerNewOrder(t *testing.T) { } } -func TestHandlerFinalizeOrder(t *testing.T) { +func TestHandler_FinalizeOrder(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) nbf := time.Now().UTC().Add(5 * time.Hour) naf := nbf.Add(17 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry.Format(time.RFC3339), - NotBefore: nbf.Format(time.RFC3339), - NotAfter: naf.Format(time.RFC3339), + Expires: expiry, + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "dns", Value: "bar.com"}, }, - Status: "valid", - Authorizations: []string{"foo", "bar"}, - Certificate: "https://ca.smallstep.com/acme/certificate/certID", + Status: "valid", + AuthorizationURLs: []string{"foo", "bar"}, + CertificateURL: "https://ca.smallstep.com/acme/certificate/certID", } _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") assert.FatalError(t, err) @@ -554,60 +512,61 @@ func TestHandlerFinalizeOrder(t *testing.T) { baseURL.String(), provName, o.ID) type test struct { - auth acme.Interface + db acme.DB + linker Linker ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -615,13 +574,13 @@ func TestHandlerFinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")), + err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, "fail/FinalizeOrder-error": func(t *testing.T) test { @@ -631,25 +590,27 @@ func TestHandlerFinalizeOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - return nil, acme.MalformedErr(errors.New("force")) + db: &acme.MockDB{ + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + /* + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, o.ID) + assert.Equals(t, incsr.Raw, csr.Raw) + */ + return acme.NewError(acme.ErrorMalformedType, "force") }, }, ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("force")), + err: acme.NewError(acme.ErrorMalformedType, "force"), } }, "ok": func(t *testing.T) test { @@ -659,28 +620,25 @@ func TestHandlerFinalizeOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - return &o, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("%s/acme/%s/order/%s", - baseURL.String(), provName, o.ID) + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + /* + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, o.ID) + assert.Equals(t, incsr.Raw, csr.Raw) + return &o, nil + */ + return nil }, }, ctx: ctx, @@ -691,7 +649,7 @@ func TestHandlerFinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: tc.linker, db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -704,15 +662,14 @@ func TestHandlerFinalizeOrder(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) diff --git a/acme/authority_test.go b/acme/authority_test.go index 8861c15e..0e8de984 100644 --- a/acme/authority_test.go +++ b/acme/authority_test.go @@ -1,25 +1,6 @@ package acme -import ( - "context" - "crypto" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql/database" - "go.step.sm/crypto/jose" -) - +/* func TestAuthorityGetLink(t *testing.T) { auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) assert.FatalError(t, err) @@ -1737,3 +1718,4 @@ func TestAuthorityDeactivateAccount(t *testing.T) { }) } } +*/ diff --git a/acme/authz_test.go b/acme/authz_test.go index 31e6bb58..206921c6 100644 --- a/acme/authz_test.go +++ b/acme/authz_test.go @@ -1,20 +1,7 @@ package acme -import ( - "context" - "encoding/json" - "strings" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - -func newAz() (authz, error) { +/* +func newAz() (*Authorization, error) { mockdb := &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), true, nil @@ -834,3 +821,4 @@ func TestAuthzUpdateStatus(t *testing.T) { }) } } +*/ diff --git a/acme/certificate_test.go b/acme/certificate_test.go index a4b8f91a..adbf8e00 100644 --- a/acme/certificate_test.go +++ b/acme/certificate_test.go @@ -1,20 +1,6 @@ package acme -import ( - "crypto/x509" - "encoding/json" - "encoding/pem" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" - "go.step.sm/crypto/pemutil" -) - +/* func defaultCertOps() (*CertOptions, error) { crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt") if err != nil { @@ -36,7 +22,7 @@ func defaultCertOps() (*CertOptions, error) { }, nil } -func newcert() (*certificate, error) { +func newcert() (*Certificate, error) { ops, err := defaultCertOps() if err != nil { return nil, err @@ -251,3 +237,4 @@ func TestCertificateToACME(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert) } +*/ diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 87ec0c4c..11b30961 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -1,38 +1,6 @@ package acme -import ( - "bytes" - "context" - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "encoding/base64" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "math/big" - "net" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" - "go.step.sm/crypto/jose" -) - +/* var testOps = ChallengeOptions{ AccountID: "accID", AuthzID: "authzID", @@ -42,7 +10,7 @@ var testOps = ChallengeOptions{ }, } -func newDNSCh() (challenge, error) { +func newDNSCh() (Challenge, error) { mockdb := &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), true, nil @@ -51,7 +19,7 @@ func newDNSCh() (challenge, error) { return newDNS01Challenge(mockdb, testOps) } -func newTLSALPNCh() (challenge, error) { +func newTLSALPNCh() (Challenge, error) { mockdb := &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), true, nil @@ -60,7 +28,7 @@ func newTLSALPNCh() (challenge, error) { return newTLSALPN01Challenge(mockdb, testOps) } -func newHTTPCh() (challenge, error) { +func newHTTPCh() (Challenge, error) { mockdb := &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), true, nil @@ -69,7 +37,7 @@ func newHTTPCh() (challenge, error) { return newHTTP01Challenge(mockdb, testOps) } -func newHTTPChWithServer(host string) (challenge, error) { +func newHTTPChWithServer(host string) (Challenge, error) { mockdb := &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), true, nil @@ -1992,3 +1960,4 @@ func TestDNS01Validate(t *testing.T) { }) } } +*/ diff --git a/acme/nonce_test.go b/acme/nonce_test.go index 6aa467a0..2088d39b 100644 --- a/acme/nonce_test.go +++ b/acme/nonce_test.go @@ -1,16 +1,6 @@ package acme -import ( - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - +/* func TestNewNonce(t *testing.T) { type test struct { db nosql.DB @@ -161,3 +151,4 @@ func TestUseNonce(t *testing.T) { }) } } +*/ diff --git a/acme/order_test.go b/acme/order_test.go index e6a8f057..5bd21fdb 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -1,24 +1,6 @@ package acme -import ( - "context" - "crypto/x509" - "crypto/x509/pkix" - "encoding/json" - "fmt" - "net" - "net/url" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - +/* var certDuration = 6 * time.Hour func defaultOrderOps() OrderOptions { @@ -1735,3 +1717,4 @@ func Test_getOrderIDsByAccount(t *testing.T) { }) } } +*/ diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index 25d74b9d..8debadde 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -31,7 +31,7 @@ func TestNewACMEClient(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", NewAccount: srv.URL + "/bar", NewOrder: srv.URL + "/baz", @@ -58,7 +58,7 @@ func TestNewACMEClient(t *testing.T) { "fail/get-directory": func(t *testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -76,7 +76,7 @@ func TestNewACMEClient(t *testing.T) { ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: dir, rc1: 200, - r2: acme.AccountDoesNotExistErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), rc2: 400, err: errors.New("Account does not exist"), } @@ -142,7 +142,7 @@ func TestNewACMEClient(t *testing.T) { func TestACMEClient_GetDirectory(t *testing.T) { c := &ACMEClient{ - dir: &acme.Directory{ + dir: &acmeAPI.Directory{ NewNonce: "/foo", NewAccount: "/bar", NewOrder: "/baz", @@ -166,7 +166,7 @@ func TestACMEClient_GetNonce(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -185,7 +185,7 @@ func TestACMEClient_GetNonce(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/GET-nonce": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -237,7 +237,7 @@ func TestACMEClient_post(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -266,7 +266,7 @@ func TestACMEClient_post(t *testing.T) { "fail/account-not-configured": func(t *testing.T) test { return test{ client: &ACMEClient{}, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("acme client not configured with account"), } @@ -274,7 +274,7 @@ func TestACMEClient_post(t *testing.T) { "fail/GET-nonce": func(t *testing.T) test { return test{ client: ac, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -365,7 +365,7 @@ func TestACMEClient_NewOrder(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", NewOrder: srv.URL + "/bar", } @@ -387,9 +387,9 @@ func TestACMEClient_NewOrder(t *testing.T) { norb, err := json.Marshal(nor) assert.FatalError(t, err) ord := acme.Order{ - Status: "valid", - Expires: "soon", - Finalize: "finalize-url", + Status: "valid", + Expires: time.Now(), // "soon" + FinalizeURL: "finalize-url", } ac := &ACMEClient{ client: &http.Client{ @@ -404,7 +404,7 @@ func TestACMEClient_NewOrder(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -413,7 +413,7 @@ func TestACMEClient_NewOrder(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, ops: []withHeaderOption{withKid(ac)}, err: errors.New("The request message was malformed"), @@ -498,7 +498,7 @@ func TestACMEClient_GetOrder(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -509,9 +509,9 @@ func TestACMEClient_GetOrder(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ord := acme.Order{ - Status: "valid", - Expires: "soon", - Finalize: "finalize-url", + Status: "valid", + Expires: time.Now(), // "soon" + FinalizeURL: "finalize-url", } ac := &ACMEClient{ client: &http.Client{ @@ -526,7 +526,7 @@ func TestACMEClient_GetOrder(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -535,7 +535,7 @@ func TestACMEClient_GetOrder(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -618,7 +618,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -628,9 +628,9 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - az := acme.Authz{ + az := acme.Authorization{ Status: "valid", - Expires: "soon", + Expires: time.Now(), Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, } ac := &ACMEClient{ @@ -646,7 +646,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -655,7 +655,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -738,7 +738,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -766,7 +766,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -775,7 +775,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -859,7 +859,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -887,7 +887,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -896,7 +896,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -976,7 +976,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -987,10 +987,10 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ord := acme.Order{ - Status: "valid", - Expires: "soon", - Finalize: "finalize-url", - Certificate: "cert-url", + Status: "valid", + Expires: time.Now(), // "soon" + FinalizeURL: "finalize-url", + CertificateURL: "cert-url", } _csr, err := pemutil.Read("../authority/testdata/certs/foo.csr") assert.FatalError(t, err) @@ -1012,7 +1012,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -1021,7 +1021,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -1101,7 +1101,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -1137,7 +1137,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { "fail/client-post": func(t *testing.T) test { return test{ client: ac, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -1147,7 +1147,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { client: ac, r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -1232,7 +1232,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -1268,7 +1268,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -1277,7 +1277,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } From f71e27e787fff9feb9b62b8aec04c5cacaae7bca Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 10 Mar 2021 23:05:46 -0800 Subject: [PATCH 16/47] [acme db interface] unit test progress --- acme/api/account_test.go | 4 ++-- acme/api/handler.go | 12 ++++++++---- acme/api/handler_test.go | 37 ++++++++++++++----------------------- acme/api/middleware.go | 2 +- acme/api/middleware_test.go | 30 +++++++++++++++--------------- acme/certificate.go | 13 ------------- acme/errors.go | 10 +++++++--- 7 files changed, 47 insertions(+), 61 deletions(-) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index d94819c7..831a218a 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -40,7 +40,7 @@ func newProv() provisioner.Interface { return p } -func TestNewAccountRequestValidate(t *testing.T) { +func TestNewAccountRequest_Validate(t *testing.T) { type test struct { nar *NewAccountRequest err *acme.Error @@ -96,7 +96,7 @@ func TestNewAccountRequestValidate(t *testing.T) { } } -func TestUpdateAccountRequestValidate(t *testing.T) { +func TestUpdateAccountRequest_Validate(t *testing.T) { type test struct { uar *UpdateAccountRequest err *acme.Error diff --git a/acme/api/handler.go b/acme/api/handler.go index a6d35bb3..3fe72d54 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -2,7 +2,9 @@ package api import ( "crypto/tls" + "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "net" "net/http" @@ -259,10 +261,12 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { return } - certBytes, err := cert.ToACME() - if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error converting cert to ACME representation")) - return + var certBytes []byte + for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { + certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: c.Raw, + })...) } api.LogCertificate(w, cert.Leaf) diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 23db72a5..8621ca18 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto/x509" "encoding/json" "encoding/pem" "fmt" @@ -47,7 +48,7 @@ func TestHandler_GetNonce(t *testing.T) { } func TestHandler_GetDirectory(t *testing.T) { - linker := NewLinker("acme", "ca.smallstep.com") + linker := NewLinker("ca.smallstep.com", "acme") prov := newProv() provName := url.PathEscape(prov.GetName()) @@ -306,7 +307,7 @@ func TestHandler_GetCertificate(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getCertificate-error": func(t *testing.T) test { + "fail/db.GetCertificate-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -319,7 +320,7 @@ func TestHandler_GetCertificate(t *testing.T) { err: acme.NewErrorISE("force"), } }, - "fail/decode-leaf-for-loggger": func(t *testing.T) test { + "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -327,28 +328,12 @@ func TestHandler_GetCertificate(t *testing.T) { db: &acme.MockDB{ MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return &acme.Certificate{}, nil + return &acme.Certificate{AccountID: "foo"}, nil }, }, ctx: ctx, - statusCode: 500, - err: acme.NewErrorISE("failed to decode any certificates from generated certBytes"), - } - }, - "fail/parse-x509-leaf-for-logger": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { - assert.Equals(t, id, certID) - return &acme.Certificate{}, nil - }, - }, - ctx: ctx, - statusCode: 500, - err: acme.NewErrorISE("failed to parse generated leaf certificate"), + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, "ok": func(t *testing.T) test { @@ -359,7 +344,13 @@ func TestHandler_GetCertificate(t *testing.T) { db: &acme.MockDB{ MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return &acme.Certificate{}, nil + return &acme.Certificate{ + AccountID: "accID", + OrderID: "ordID", + Leaf: leaf, + Intermediates: []*x509.Certificate{inter, root}, + ID: id, + }, nil }, }, ctx: ctx, diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 7a3529cd..f2a35c3a 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -411,7 +411,7 @@ const ( func accountFromContext(ctx context.Context) (*acme.Account, error) { val, ok := ctx.Value(accContextKey).(*acme.Account) if !ok || val == nil { - return nil, acme.NewErrorISE("account not in context") + return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context") } return val, nil } diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index c6c855a8..4f2c4bcb 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -81,7 +81,7 @@ func Test_baseURLFromRequest(t *testing.T) { } } -func TestHandlerBaseURLFromRequest(t *testing.T) { +func TestHandler_baseURLFromRequest(t *testing.T) { h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req.Host = "test.ca.smallstep.com:8080" @@ -107,7 +107,7 @@ func TestHandlerBaseURLFromRequest(t *testing.T) { h.baseURLFromRequest(next)(w, req) } -func TestHandler_AddNonce(t *testing.T) { +func TestHandler_addNonce(t *testing.T) { url := "https://ca.smallstep.com/acme/new-nonce" type test struct { db acme.DB @@ -226,7 +226,7 @@ func TestHandler_addDirLink(t *testing.T) { } } -func TestHandler_VerifyContentType(t *testing.T) { +func TestHandler_verifyContentType(t *testing.T) { prov := newProv() provName := prov.GetName() baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -340,7 +340,7 @@ func TestHandler_VerifyContentType(t *testing.T) { } } -func TestHandlerIsPostAsGet(t *testing.T) { +func TestHandler_isPostAsGet(t *testing.T) { url := "https://ca.smallstep.com/acme/new-account" type test struct { ctx context.Context @@ -417,7 +417,7 @@ func (errReader) Close() error { return nil } -func TestHandlerParseJWS(t *testing.T) { +func TestHandler_parseJWS(t *testing.T) { url := "https://ca.smallstep.com/acme/new-account" type test struct { next nextHTTP @@ -498,7 +498,7 @@ func TestHandlerParseJWS(t *testing.T) { } } -func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { +func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := jwk.Public() @@ -558,7 +558,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _pub := _jwk.Public() ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwsContextKey, &_pub) + ctx = context.WithValue(ctx, jwkContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, @@ -570,7 +570,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { clone := &_pub clone.Algorithm = jose.HS256 ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwsContextKey, clone) + ctx = context.WithValue(ctx, jwkContextKey, clone) return test{ ctx: ctx, statusCode: 400, @@ -579,7 +579,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { }, "ok": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwsContextKey, pub) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -600,7 +600,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { clone := &_pub clone.Algorithm = "" ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwsContextKey, pub) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -624,7 +624,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) - ctx = context.WithValue(ctx, jwsContextKey, pub) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -648,7 +648,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) - ctx = context.WithValue(ctx, jwsContextKey, pub) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -697,7 +697,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { } } -func TestHandlerLookupJWK(t *testing.T) { +func TestHandler_lookupJWK(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -899,7 +899,7 @@ func TestHandlerLookupJWK(t *testing.T) { } } -func TestHandlerExtractJWK(t *testing.T) { +func TestHandler_extractJWK(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1095,7 +1095,7 @@ func TestHandlerExtractJWK(t *testing.T) { } } -func TestHandlerValidateJWS(t *testing.T) { +func TestHandler_validateJWS(t *testing.T) { url := "https://ca.smallstep.com/acme/account/1234" type test struct { db acme.DB diff --git a/acme/certificate.go b/acme/certificate.go index daf9556b..d46d1a08 100644 --- a/acme/certificate.go +++ b/acme/certificate.go @@ -2,7 +2,6 @@ package acme import ( "crypto/x509" - "encoding/pem" ) // Certificate options with which to create and store a cert object. @@ -13,15 +12,3 @@ type Certificate struct { Leaf *x509.Certificate Intermediates []*x509.Certificate } - -// ToACME encodes the entire X509 chain into a PEM list. -func (cert *Certificate) ToACME() ([]byte, error) { - var ret []byte - for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { - ret = append(ret, pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: c.Raw, - })...) - } - return ret, nil -} diff --git a/acme/errors.go b/acme/errors.go index 41305c87..54182ec2 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -271,6 +271,10 @@ type Error struct { // NewError creates a new Error type. func NewError(pt ProblemType, msg string, args ...interface{}) *Error { + return newError(pt, errors.Errorf(msg, args...)) +} + +func newError(pt ProblemType, err error) *Error { meta, ok := errorMap[pt] if !ok { meta = errorServerInternalMetadata @@ -278,7 +282,7 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error { Type: meta.typ, Detail: meta.details, Status: meta.status, - Err: errors.Errorf("unrecognized problemType %v", pt), + Err: err, } } @@ -286,7 +290,7 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error { Type: meta.typ, Detail: meta.details, Status: meta.status, - Err: errors.Errorf(msg, args...), + Err: err, } } @@ -308,7 +312,7 @@ func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Err } return e default: - return NewError(ErrorServerInternalType, msg, args...) + return newError(typ, errors.Wrapf(err, msg, args...)) } } From 291fd5d45a57ef90eddc4923b15aee9293635beb Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 10 Mar 2021 23:59:02 -0800 Subject: [PATCH 17/47] [acme db interface] more unit tests --- acme/api/handler.go | 1 + acme/api/handler_test.go | 217 ++++++++++++++++++++++----------------- acme/api/middleware.go | 2 +- 3 files changed, 127 insertions(+), 93 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index 3fe72d54..5960d49c 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -171,6 +171,7 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { } if err = az.UpdateStatus(ctx, h.db); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + return } h.linker.LinkAuthorization(ctx, az) diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 8621ca18..34c720f1 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -16,7 +16,6 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) @@ -110,10 +109,11 @@ func TestHandler_GetDirectory(t *testing.T) { } } -func TestHandler_GetAuthz(t *testing.T) { +func TestHandler_GetAuthorization(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) az := acme.Authorization{ - ID: "authzID", + ID: "authzID", + AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "example.com", @@ -147,7 +147,7 @@ func TestHandler_GetAuthz(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", az.ID) - url := fmt.Sprintf("%s/acme/%s/challenge/%s", + url := fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, az.ID) type test struct { @@ -175,7 +175,7 @@ func TestHandler_GetAuthz(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getAuthz-error": func(t *testing.T) test { + "fail/db.GetAuthorization-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -188,6 +188,48 @@ func TestHandler_GetAuthz(t *testing.T) { err: acme.NewErrorISE("force"), } }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { + assert.Equals(t, id, az.ID) + return &acme.Authorization{ + AccountID: "foo", + }, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/db.UpdateAuthorization-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { + assert.Equals(t, id, az.ID) + return &acme.Authorization{ + AccountID: "accID", + Status: acme.StatusPending, + Expires: time.Now().Add(-1 * time.Hour), + }, nil + }, + MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + assert.Equals(t, az.Status, acme.StatusInvalid) + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) @@ -200,15 +242,6 @@ func TestHandler_GetAuthz(t *testing.T) { assert.Equals(t, id, az.ID) return &az, nil }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.True(t, abs) - assert.Equals(t, in, []string{az.ID}) - return url - }, - */ }, ctx: ctx, statusCode: 200, @@ -218,7 +251,7 @@ func TestHandler_GetAuthz(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -402,7 +435,7 @@ func ch() acme.Challenge { } } -func TestHandlerGetChallenge(t *testing.T) { +func TestHandler_GetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") prov := newProv() @@ -437,8 +470,8 @@ func TestHandlerGetChallenge(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, - statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + statusCode: 400, + err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { @@ -448,88 +481,88 @@ func TestHandlerGetChallenge(t *testing.T) { ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, - statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), - } - }, - "fail/validate-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - }, - ctx: ctx, - statusCode: 401, - err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - }, - ctx: ctx, - statusCode: 401, - err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + statusCode: 400, + err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), } }, - "ok/validate-challenge": func(t *testing.T) test { - key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ch := ch() - ch.Status = "valid" - ch.Validated = time.Now().UTC().Format(time.RFC3339) - return test{ - db: &acme.MockDB{ - MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { - assert.Equals(t, chID, ch.ID) - return &ch, nil + /* + "fail/validate-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - var ret string - switch count { - case 0: - assert.Equals(t, typ, acme.AuthzLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) - case 1: - assert.Equals(t, typ, acme.ChallengeLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.ID}) - ret = url - } - count++ - return ret + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + } + }, + "fail/get-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + } + }, + "ok/validate-challenge": func(t *testing.T) test { + key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ID: "accID", Key: key} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ch := ch() + ch.Status = "valid" + ch.Validated = time.Now().UTC().Format(time.RFC3339) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, ch.ID) + return &ch, nil }, - */ - }, - ctx: ctx, - statusCode: 200, - ch: ch, - } - }, + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + var ret string + switch count { + case 0: + assert.Equals(t, typ, acme.AuthzLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.AuthzID}) + ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) + case 1: + assert.Equals(t, typ, acme.ChallengeLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.ID}) + ret = url + } + count++ + return ret + }, + }, + ctx: ctx, + statusCode: 200, + ch: ch, + } + }, + */ } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() diff --git a/acme/api/middleware.go b/acme/api/middleware.go index f2a35c3a..a021c936 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -462,7 +462,7 @@ func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { func payloadFromContext(ctx context.Context) (*payloadInfo, error) { val, ok := ctx.Value(payloadContextKey).(*payloadInfo) if !ok || val == nil { - return nil, acme.NewErrorISE("payload expected in request context") + return nil, acme.NewError(acme.ErrorMalformedType, "payload expected in request context") } return val, nil } From 20b9785d2002443966cfef44663d6e8fbe65639e Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 11 Mar 2021 13:10:14 -0800 Subject: [PATCH 18/47] [acme db interface] continuing unit test work --- acme/api/account.go | 2 +- acme/api/account_test.go | 80 ++++--------- acme/api/handler.go | 36 +++--- acme/api/handler_test.go | 247 +++++++++++++++++++++++++++------------ acme/api/middleware.go | 2 +- acme/challenge.go | 12 +- 6 files changed, 226 insertions(+), 153 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index 30d406e4..2e15ad40 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -105,7 +105,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } - acc := &acme.Account{ + acc = &acme.Account{ Key: jwk, Contact: nar.Contact, Status: acme.StatusValid, diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 831a218a..d8fdff84 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -278,18 +278,13 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { } func TestHandler_NewAccount(t *testing.T) { - accID := "accountID" - acc := acme.Account{ - ID: accID, - Status: "valid", - Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), - } prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { db acme.DB + acc *acme.Account ctx context.Context statusCode int err *acme.Error @@ -372,7 +367,7 @@ func TestHandler_NewAccount(t *testing.T) { err: acme.NewErrorISE("jwk expected in request context"), } }, - "fail/NewAccount-error": func(t *testing.T) test { + "fail/db.CreateAccount-error": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } @@ -410,20 +405,18 @@ func TestHandler_NewAccount(t *testing.T) { return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + acc.ID = "accountID" assert.Equals(t, acc.Contact, nar.Contact) assert.Equals(t, acc.Key, jwk) return nil }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - */ + }, + acc: &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Orders: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders", }, ctx: ctx, statusCode: 201, @@ -435,12 +428,21 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + } ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ ctx: ctx, + acc: acc, statusCode: 200, } }, @@ -448,7 +450,7 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -471,19 +473,19 @@ func TestHandler_NewAccount(t *testing.T) { assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(acc) + expB, err := json.Marshal(tc.acc) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), - provName, accID)}) + provName, "accountID")}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } -func TestHandlerGetUpdateAccount(t *testing.T) { +func TestHandler_GetUpdateAccount(t *testing.T) { accID := "accountID" acc := acme.Account{ ID: accID, @@ -594,16 +596,6 @@ func TestHandlerGetUpdateAccount(t *testing.T) { assert.Equals(t, upd.ID, acc.ID) return nil }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - */ }, ctx: ctx, statusCode: 200, @@ -639,16 +631,6 @@ func TestHandlerGetUpdateAccount(t *testing.T) { assert.Equals(t, upd.ID, acc.ID) return nil }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - */ }, ctx: ctx, statusCode: 200, @@ -660,18 +642,6 @@ func TestHandlerGetUpdateAccount(t *testing.T) { ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - /* - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL, provName, accID) - }, - }, - */ ctx: ctx, statusCode: 200, } @@ -680,7 +650,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() diff --git a/acme/api/handler.go b/acme/api/handler.go index 5960d49c..47c93dfc 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -38,10 +38,11 @@ type payloadInfo struct { // Handler is the ACME API request handler. type Handler struct { - db acme.DB - backdate provisioner.Duration - ca acme.CertificateAuthority - linker Linker + db acme.DB + backdate provisioner.Duration + ca acme.CertificateAuthority + linker Linker + validateChallengeOptions *acme.ValidateChallengeOptions } // HandlerOptions required to create a new ACME API request handler. @@ -63,11 +64,24 @@ type HandlerOptions struct { // NewHandler returns a new ACME API handler. func NewHandler(ops HandlerOptions) api.RouterHandler { + client := http.Client{ + Timeout: time.Duration(30 * time.Second), + } + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + } return &Handler{ ca: ops.CA, db: ops.DB, backdate: ops.Backdate, linker: NewLinker(ops.DNS, ops.Prefix), + validateChallengeOptions: &acme.ValidateChallengeOptions{ + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, config) + }, + }, } } @@ -212,24 +226,12 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } - client := http.Client{ - Timeout: time.Duration(30 * time.Second), - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } - if err = ch.Validate(ctx, h.db, jwk, acme.ValidateOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }); err != nil { + if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) return } diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 34c720f1..70d2dc14 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -8,14 +8,17 @@ import ( "encoding/pem" "fmt" "io/ioutil" + "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/go-chi/chi" + "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) @@ -438,16 +441,21 @@ func ch() acme.Challenge { func TestHandler_GetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") + chiCtx.URLParams.Add("authzID", "authzID") prov := newProv() provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID") + + url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", + baseURL.String(), provName, "authzID", "chID") type test struct { db acme.DB + vco *acme.ValidateChallengeOptions ctx context.Context statusCode int - ch acme.Challenge + ch *acme.Challenge err *acme.Error } var tests = map[string]func(t *testing.T) test{ @@ -485,84 +493,177 @@ func TestHandler_GetChallenge(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), } }, - /* - "fail/validate-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + "fail/db.GetChallenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return nil, acme.NewErrorISE("force") }, - ctx: ctx, - statusCode: 401, - err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{AccountID: "foo"}, nil }, - ctx: ctx, - statusCode: 401, - err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - } - }, - "ok/validate-challenge": func(t *testing.T) test { - key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ch := ch() - ch.Status = "valid" - ch.Validated = time.Now().UTC().Format(time.RFC3339) - return test{ - db: &acme.MockDB{ - MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { - assert.Equals(t, chID, ch.ID) - return &ch, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - var ret string - switch count { - case 0: - assert.Equals(t, typ, acme.AuthzLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) - case 1: - assert.Equals(t, typ, acme.ChallengeLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.ID}) - ret = url - } - count++ - return ret - }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "accout id mismatch"), + } + }, + "fail/no-jwk": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{AccountID: "accID"}, nil }, - ctx: ctx, - statusCode: 200, - ch: ch, - } - }, - */ + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("missing jwk"), + } + }, + "fail/nil-jwk": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, jwkContextKey, nil) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{AccountID: "accID"}, nil + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("nil jwk"), + } + }, + "fail/validate-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + _pub := _jwk.Public() + ctx = context.WithValue(ctx, jwkContextKey, &_pub) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{ + Status: acme.StatusPending, + Type: "http-01", + AccountID: "accID", + }, nil + }, + MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.AccountID, "accID") + assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) + return acme.NewErrorISE("force") + }, + }, + vco: &acme.ValidateChallengeOptions{ + HTTPGet: func(string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + _pub := _jwk.Public() + ctx = context.WithValue(ctx, jwkContextKey, &_pub) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{ + ID: "chID", + AuthzID: "authzID", + Status: acme.StatusPending, + Type: "http-01", + AccountID: "accID", + }, nil + }, + MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.AccountID, "accID") + assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) + return nil + }, + }, + ch: &acme.Challenge{ + ID: "chID", + AuthzID: "authzID", + Status: acme.StatusPending, + Type: "http-01", + AccountID: "accID", + URL: url, + Error: acme.NewError(acme.ErrorConnectionType, "force"), + }, + vco: &acme.ValidateChallengeOptions{ + HTTPGet: func(string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + ctx: ctx, + statusCode: 200, + } + }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() diff --git a/acme/api/middleware.go b/acme/api/middleware.go index a021c936..f2a35c3a 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -462,7 +462,7 @@ func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { func payloadFromContext(ctx context.Context) (*payloadInfo, error) { val, ok := ctx.Value(payloadContextKey).(*payloadInfo) if !ok || val == nil { - return nil, acme.NewError(acme.ErrorMalformedType, "payload expected in request context") + return nil, acme.NewErrorISE("payload expected in request context") } return val, nil } diff --git a/acme/challenge.go b/acme/challenge.go index 2abc808c..2c6f5fb1 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -47,7 +47,7 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { // If already valid or invalid then return without performing validation. if ch.Status == StatusValid || ch.Status == StatusInvalid { return nil @@ -64,7 +64,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token) resp, err := vo.HTTPGet(url) @@ -105,7 +105,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return nil } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, ServerName: ch.Value, @@ -197,7 +197,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com @@ -263,8 +263,8 @@ type httpGetter func(string) (*http.Response, error) type lookupTxt func(string) ([]string, error) type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) -// ValidateOptions are ACME challenge validator functions. -type ValidateOptions struct { +// ValidateChallengeOptions are ACME challenge validator functions. +type ValidateChallengeOptions struct { HTTPGet httpGetter LookupTxt lookupTxt TLSDial tlsDialer From 8d2ebcfd497a1f062512a5f8222c1c93bd935221 Mon Sep 17 00:00:00 2001 From: max furman Date: Fri, 12 Mar 2021 00:16:48 -0800 Subject: [PATCH 19/47] [acme db interface] more unit tests --- acme/account.go | 10 -- acme/api/account.go | 20 ++- acme/api/account_test.go | 28 +-- acme/api/order_test.go | 371 +++++++++++++++++++++++++++++---------- acme/order.go | 4 +- 5 files changed, 314 insertions(+), 119 deletions(-) diff --git a/acme/account.go b/acme/account.go index 354ebdc7..cb60e21d 100644 --- a/acme/account.go +++ b/acme/account.go @@ -27,16 +27,6 @@ func (a *Account) ToLog() (interface{}, error) { return string(b), nil } -// GetID returns the account ID. -func (a *Account) GetID() string { - return a.ID -} - -// GetKey returns the JWK associated with the account. -func (a *Account) GetKey() *jose.JSONWebKey { - return a.Key -} - // IsValid returns true if the Account is valid. func (a *Account) IsValid() bool { return Status(a.Status) == StatusValid diff --git a/acme/api/account.go b/acme/api/account.go index 2e15ad40..c7f3d11a 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -153,15 +153,17 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - var err error - // If neither the status nor the contacts are being updated then ignore - // the updates and return 200. This conforms with the behavior detailed - // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). - acc.Status = uar.Status - acc.Contact = uar.Contact - if err = h.db.UpdateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) - return + if len(uar.Status) > 0 || len(uar.Contact) > 0 { + if len(uar.Status) > 0 { + acc.Status = uar.Status + } else if len(uar.Contact) > 0 { + acc.Contact = uar.Contact + } + + if err := h.db.UpdateAccount(ctx, acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + return + } } } diff --git a/acme/api/account_test.go b/acme/api/account_test.go index d8fdff84..28abffe1 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -168,16 +168,22 @@ func TestUpdateAccountRequest_Validate(t *testing.T) { } func TestHandler_GetOrdersByAccountID(t *testing.T) { - oids := []string{ - "https://ca.smallstep.com/acme/order/foo", - "https://ca.smallstep.com/acme/order/bar", + oids := []string{"foo", "bar"} + oidURLs := []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/foo", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/bar", } accID := "account-id" // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("accID", accID) - url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID) + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) type test struct { db acme.DB @@ -189,15 +195,15 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ db: &acme.MockDB{}, - ctx: ctx, + ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } @@ -213,7 +219,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"), } }, - "fail/getOrdersByAccount-error": func(t *testing.T) test { + "fail/db.GetOrdersByAccountID-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -230,6 +236,8 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { acc := &acme.Account{ID: accID} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { @@ -245,7 +253,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -268,7 +276,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(oids) + expB, err := json.Marshal(oidURLs) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -558,7 +566,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, - "fail/update-error": func(t *testing.T) test { + "fail/db.UpdateAccount-error": func(t *testing.T) test { uar := &UpdateAccountRequest{ Status: "deactivated", } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 610713b6..b6783e34 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -19,7 +19,7 @@ import ( "go.step.sm/crypto/pemutil" ) -func TestNewOrderRequestValidate(t *testing.T) { +func TestNewOrderRequest_Validate(t *testing.T) { type test struct { nor *NewOrderRequest nbf, naf time.Time @@ -148,12 +148,12 @@ func TestFinalizeRequestValidate(t *testing.T) { } func TestHandler_GetOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC() - naf := time.Now().UTC().Add(24 * time.Hour) + now := clock.Now() + nbf := now + naf := now.Add(24 * time.Hour) + expiry := now.Add(-time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ @@ -166,8 +166,15 @@ func TestHandler_GetOrder(t *testing.T) { Value: "*.smallstep.com", }, }, - Status: "pending", - AuthorizationURLs: []string{"foo", "bar"}, + Expires: expiry, + Status: acme.StatusInvalid, + Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), + AuthorizationURLs: []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", + }, + FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", } // Request with chi context @@ -181,7 +188,6 @@ func TestHandler_GetOrder(t *testing.T) { type test struct { db acme.DB - linker Linker ctx context.Context statusCode int err *acme.Error @@ -203,8 +209,27 @@ func TestHandler_GetOrder(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getOrder-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/db.GetOrder-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -217,8 +242,64 @@ func TestHandler_GetOrder(t *testing.T) { err: acme.NewErrorISE("force"), } }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "foo"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/provisioner-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), + } + }, + "fail/order-update-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: clock.Now().Add(-time.Hour), + Status: acme.StatusReady, + }, nil + }, + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, "ok": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -226,11 +307,31 @@ func TestHandler_GetOrder(t *testing.T) { return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { - assert.Equals(t, id, o.ID) - return &o, nil + return &acme.Order{ + ID: "orderID", + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: expiry, + Status: acme.StatusReady, + AuthorizationIDs: []string{"foo", "bar", "baz"}, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + }, nil + }, + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return nil }, }, - linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 200, } @@ -239,7 +340,7 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker, db: tc.db} + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -264,6 +365,7 @@ func TestHandler_GetOrder(t *testing.T) { } else { expB, err := json.Marshal(o) assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -272,7 +374,7 @@ func TestHandler_GetOrder(t *testing.T) { } } -func TestHandlerNewOrder(t *testing.T) { +func TestHandler_NewOrder(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) nbf := time.Now().UTC().Add(5 * time.Hour) naf := nbf.Add(17 * time.Hour) @@ -297,7 +399,6 @@ func TestHandlerNewOrder(t *testing.T) { type test struct { db acme.DB - linker Linker ctx context.Context statusCode int err *acme.Error @@ -319,14 +420,23 @@ func TestHandlerNewOrder(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/no-payload": func(t *testing.T) test { + "fail/no-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner expected in request context"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("provisioner expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { @@ -408,10 +518,10 @@ func TestHandlerNewOrder(t *testing.T) { return test{ db: &acme.MockDB{ MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "orderID" return nil }, }, - linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 201, } @@ -436,7 +546,6 @@ func TestHandlerNewOrder(t *testing.T) { return nil }, }, - linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 201, } @@ -445,7 +554,7 @@ func TestHandlerNewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker, db: tc.db} + h := &Handler{linker: NewLinker("dns", "prefix"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -481,26 +590,33 @@ func TestHandlerNewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC().Add(5 * time.Hour) - naf := nbf.Add(17 * time.Hour) + now := clock.Now() + nbf := now + naf := now.Add(24 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, }, - Status: "valid", - AuthorizationURLs: []string{"foo", "bar"}, - CertificateURL: "https://ca.smallstep.com/acme/certificate/certID", + Expires: naf, + Status: acme.StatusValid, + AuthorizationURLs: []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", + }, + FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", + CertificateURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/certificate/certID", } - _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") - assert.FatalError(t, err) - csr, ok := _csr.(*x509.CertificateRequest) - assert.Fatal(t, ok) // Request with chi context chiCtx := chi.NewRouteContext() @@ -508,12 +624,22 @@ func TestHandler_FinalizeOrder(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/order/%s/finalize", + url := fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) + _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") + assert.FatalError(t, err) + csr, ok := _csr.(*x509.CertificateRequest) + assert.Fatal(t, ok) + + nor := &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + } + payloadBytes, err := json.Marshal(nor) + assert.FatalError(t, err) + type test struct { db acme.DB - linker Linker ctx context.Context statusCode int err *acme.Error @@ -521,7 +647,6 @@ func TestHandler_FinalizeOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -531,31 +656,49 @@ func TestHandler_FinalizeOrder(t *testing.T) { ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ - db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/no-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/no-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("paylod does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { @@ -583,62 +726,112 @@ func TestHandler_FinalizeOrder(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, - "fail/FinalizeOrder-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - nor := &FinalizeRequest{ - CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + "fail/db.GetOrder-error": func(t *testing.T) test { + + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), } - b, err := json.Marshal(nor) - assert.FatalError(t, err) + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ - MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { - /* - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - */ - return acme.NewError(acme.ErrorMalformedType, "force") + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "foo"}, nil }, }, ctx: ctx, - statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "force"), + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, - "ok": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - nor := &FinalizeRequest{ - CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + "fail/provisioner-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), } - b, err := json.Marshal(nor) - assert.FatalError(t, err) + }, + "fail/order-finalize-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: clock.Now().Add(-time.Hour), + Status: acme.StatusReady, + }, nil + }, MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { - /* - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - return &o, nil - */ - return nil + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + ID: "orderID", + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: naf, + Status: acme.StatusValid, + AuthorizationIDs: []string{"foo", "bar", "baz"}, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + CertificateID: "certID", + }, nil }, }, ctx: ctx, @@ -649,7 +842,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker, db: tc.db} + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -674,10 +867,12 @@ func TestHandler_FinalizeOrder(t *testing.T) { } else { expB, err := json.Marshal(o) assert.FatalError(t, err) + + ro := new(acme.Order) + err = json.Unmarshal(body, ro) + assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("%s/acme/%s/order/%s", - baseURL, provName, o.ID)}) + assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/order.go b/acme/order.go index 7b0b2d4d..a2c89fe7 100644 --- a/acme/order.go +++ b/acme/order.go @@ -20,6 +20,7 @@ type Identifier struct { // Order contains order metadata for the ACME protocol order type. type Order struct { + ID string `json:"id"` Status Status `json:"status"` Expires time.Time `json:"expires,omitempty"` Identifiers []Identifier `json:"identifiers"` @@ -31,7 +32,6 @@ type Order struct { FinalizeURL string `json:"finalize"` CertificateID string `json:"-"` CertificateURL string `json:"certificate,omitempty"` - ID string `json:"-"` AccountID string `json:"-"` ProvisionerID string `json:"-"` DefaultDuration time.Duration `json:"-"` @@ -50,7 +50,7 @@ func (o *Order) ToLog() (interface{}, error) { // UpdateStatus updates the ACME Order Status if necessary. // Changes to the order are saved using the database interface. func (o *Order) UpdateStatus(ctx context.Context, db DB) error { - now := time.Now().UTC() + now := clock.Now() switch o.Status { case StatusInvalid: From 074ab7b2217c08322f6f5668fc4d163ebf592798 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 15 Mar 2021 10:30:12 -0700 Subject: [PATCH 20/47] [acme db interface] add linker tests --- acme/account.go | 10 +- acme/api/account_test.go | 16 +-- acme/api/linker.go | 2 +- acme/api/linker_test.go | 211 ++++++++++++++++++++++++++++++++++++++- ca/acmeClient.go | 4 +- ca/acmeClient_test.go | 26 ++--- 6 files changed, 236 insertions(+), 33 deletions(-) diff --git a/acme/account.go b/acme/account.go index cb60e21d..3b6bafed 100644 --- a/acme/account.go +++ b/acme/account.go @@ -11,11 +11,11 @@ import ( // Account is a subset of the internal account type containing only those // attributes required for responses in the ACME protocol. type Account struct { - Contact []string `json:"contact,omitempty"` - Status Status `json:"status"` - Orders string `json:"orders"` - ID string `json:"-"` - Key *jose.JSONWebKey `json:"-"` + Contact []string `json:"contact,omitempty"` + Status Status `json:"status"` + OrdersURL string `json:"orders"` + ID string `json:"-"` + Key *jose.JSONWebKey `json:"-"` } // ToLog enables response logging. diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 28abffe1..918c31c5 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -420,11 +420,11 @@ func TestHandler_NewAccount(t *testing.T) { }, }, acc: &acme.Account{ - ID: "accountID", - Key: jwk, - Status: acme.StatusValid, - Contact: []string{"foo", "bar"}, - Orders: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders", + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + OrdersURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders", }, ctx: ctx, statusCode: 201, @@ -496,9 +496,9 @@ func TestHandler_NewAccount(t *testing.T) { func TestHandler_GetUpdateAccount(t *testing.T) { accID := "accountID" acc := acme.Account{ - ID: accID, - Status: "valid", - Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), + ID: accID, + Status: "valid", + OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), } prov := newProv() provName := url.PathEscape(prov.GetName()) diff --git a/acme/api/linker.go b/acme/api/linker.go index 6688732d..b5995852 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -160,7 +160,7 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { // LinkAccount sets the ACME links required by an ACME account. func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { - acc.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) + acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go index ab1ad3ba..c3075a1a 100644 --- a/acme/api/linker_test.go +++ b/acme/api/linker_test.go @@ -7,9 +7,10 @@ import ( "testing" "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" ) -func TestLinkerGetLink(t *testing.T) { +func TestLinker_GetLink(t *testing.T) { dns := "ca.smallstep.com" prefix := "acme" linker := NewLinker(dns, prefix) @@ -42,7 +43,7 @@ func TestLinkerGetLink(t *testing.T) { assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName)) } -func TestLinkerGetLinkExplicit(t *testing.T) { +func TestLinker_GetLinkExplicit(t *testing.T) { dns := "ca.smallstep.com" baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} prefix := "acme" @@ -91,9 +92,211 @@ func TestLinkerGetLinkExplicit(t *testing.T) { assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) - assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) - assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) } + +func TestLinker_LinkOrder(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + oid := "orderID" + certID := "certID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + o *acme.Order + validate func(o *acme.Order) + } + var tests = map[string]test{ + "no-authz-and-no-cert": { + o: &acme.Order{ + ID: oid, + }, + validate: func(o *acme.Order) { + assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) + assert.Equals(t, o.AuthorizationURLs, []string{}) + assert.Equals(t, o.CertificateURL, "") + }, + }, + "one-authz-and-cert": { + o: &acme.Order{ + ID: oid, + CertificateID: certID, + AuthorizationIDs: []string{"foo"}, + }, + validate: func(o *acme.Order) { + assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) + assert.Equals(t, o.AuthorizationURLs, []string{ + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), + }) + assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID)) + }, + }, + "many-authz": { + o: &acme.Order{ + ID: oid, + CertificateID: certID, + AuthorizationIDs: []string{"foo", "bar", "zap"}, + }, + validate: func(o *acme.Order) { + assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) + assert.Equals(t, o.AuthorizationURLs, []string{ + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"), + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"), + }) + assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkOrder(ctx, tc.o) + tc.validate(tc.o) + }) + } +} + +func TestLinker_LinkAccount(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + accID := "accountID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + a *acme.Account + validate func(o *acme.Account) + } + var tests = map[string]test{ + "ok": { + a: &acme.Account{ + ID: accID, + }, + validate: func(a *acme.Account) { + assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkAccount(ctx, tc.a) + tc.validate(tc.a) + }) + } +} + +func TestLinker_LinkChallenge(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + chID := "chID" + azID := "azID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + ch *acme.Challenge + validate func(o *acme.Challenge) + } + var tests = map[string]test{ + "ok": { + ch: &acme.Challenge{ + ID: chID, + AuthzID: azID, + }, + validate: func(ch *acme.Challenge) { + assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, ch.AuthzID, ch.ID)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkChallenge(ctx, tc.ch) + tc.validate(tc.ch) + }) + } +} + +func TestLinker_LinkAuthorization(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + chID0 := "chID-0" + chID1 := "chID-1" + chID2 := "chID-2" + azID := "azID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + az *acme.Authorization + validate func(o *acme.Authorization) + } + var tests = map[string]test{ + "ok": { + az: &acme.Authorization{ + ID: azID, + Challenges: []*acme.Challenge{ + {ID: chID0, AuthzID: azID}, + {ID: chID1, AuthzID: azID}, + {ID: chID2, AuthzID: azID}, + }, + }, + validate: func(az *acme.Authorization) { + assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) + assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) + assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkAuthorization(ctx, tc.az) + tc.validate(tc.az) + }) + } +} + +func TestLinker_LinkOrdersByAccountID(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + oids []string + } + var tests = map[string]test{ + "ok": { + oids: []string{"foo", "bar", "baz"}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkOrdersByAccountID(ctx, tc.oids) + assert.Equals(t, tc.oids, []string{ + fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"), + fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"), + fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"), + }) + }) + } +} diff --git a/ca/acmeClient.go b/ca/acmeClient.go index b19ad664..5633dac5 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -320,7 +320,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) { if c.acc == nil { return nil, errors.New("acme client not configured with account") } - resp, err := c.post(nil, c.acc.Orders, withKid(c)) + resp, err := c.post(nil, c.acc.OrdersURL, withKid(c)) if err != nil { return nil, err } @@ -330,7 +330,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) { var orders []string if err := readJSON(resp.Body, &orders); err != nil { - return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders) + return nil, errors.Wrapf(err, "error reading %s", c.acc.OrdersURL) } return orders, nil diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index 8debadde..08d4b734 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -40,9 +40,9 @@ func TestNewACMEClient(t *testing.T) { KeyChange: srv.URL + "/blorp", } acc := acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: "orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: "orders-url", } tests := map[string]func(t *testing.T) test{ "fail/client-option-error": func(t *testing.T) test { @@ -248,9 +248,9 @@ func TestACMEClient_post(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: "orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: "orders-url", } ac := &ACMEClient{ client: &http.Client{ @@ -1121,9 +1121,9 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { Key: jwk, kid: "foobar", acc: &acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: srv.URL + "/orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: srv.URL + "/orders-url", }, } @@ -1198,7 +1198,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) - assert.Equals(t, jwsURL, ac.acc.Orders) + assert.Equals(t, jwsURL, ac.acc.OrdersURL) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) @@ -1259,9 +1259,9 @@ func TestACMEClient_GetCertificate(t *testing.T) { Key: jwk, kid: "foobar", acc: &acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: srv.URL + "/orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: srv.URL + "/orders-url", }, } From 4b1dda5bb62e245cc3bc1a25cfce41f73a61dd49 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 18 Mar 2021 14:50:55 -0700 Subject: [PATCH 21/47] [acme db interface] tests --- acme/db/nosql/certificate.go | 32 +-- acme/db/nosql/certificate_test.go | 311 ++++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+), 26 deletions(-) create mode 100644 acme/db/nosql/certificate_test.go diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go index ef766843..bd46df4d 100644 --- a/acme/db/nosql/certificate.go +++ b/acme/db/nosql/certificate.go @@ -59,47 +59,27 @@ func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, if nosql.IsErrNotFound(err) { return nil, errors.Wrapf(err, "certificate %s not found", id) } else if err != nil { - return nil, errors.Wrap(err, "error loading certificate") + return nil, errors.Wrapf(err, "error loading certificate %s", id) } dbC := new(dbCert) if err := json.Unmarshal(b, dbC); err != nil { - return nil, errors.Wrap(err, "error unmarshaling certificate") + return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id) } - leaf, err := parseCert(dbC.Leaf) + certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...)) if err != nil { - return nil, errors.Wrapf(err, "error parsing leaf of ACME Certificate with ID '%s'", id) - } - - intermediates, err := parseBundle(dbC.Intermediates) - if err != nil { - return nil, errors.Wrapf(err, "error parsing intermediate bundle of ACME Certificate with ID '%s'", id) + return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id) } return &acme.Certificate{ ID: dbC.ID, AccountID: dbC.AccountID, OrderID: dbC.OrderID, - Leaf: leaf, - Intermediates: intermediates, + Leaf: certs[0], + Intermediates: certs[1:], }, nil } -func parseCert(b []byte) (*x509.Certificate, error) { - block, rest := pem.Decode(b) - if block == nil || len(rest) > 0 { - return nil, errors.New("error decoding PEM block: contains unexpected data") - } - if block.Type != "CERTIFICATE" { - return nil, errors.New("error decoding PEM: block is not a certificate bundle") - } - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, errors.Wrap(err, "error parsing x509 certificate") - } - return cert, nil -} - func parseBundle(b []byte) ([]*x509.Certificate, error) { var ( err error diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go new file mode 100644 index 00000000..83f07f5e --- /dev/null +++ b/acme/db/nosql/certificate_test.go @@ -0,0 +1,311 @@ +package nosql + +import ( + "context" + "crypto/x509" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + + "go.step.sm/crypto/pemutil" +) + +func TestDB_CreateCertificate(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + type test struct { + db nosql.DB + cert *acme.Certificate + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + cert := &acme.Certificate{ + AccountID: "accounttID", + OrderID: "orderID", + Leaf: leaf, + Intermediates: []*x509.Certificate{inter, root}, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + assert.Equals(t, old, nil) + + dbc := new(dbCert) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.FatalError(t, err) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.ID, cert.ID) + assert.Equals(t, dbc.AccountID, cert.AccountID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created)) + return nil, false, errors.New("force") + }, + }, + cert: cert, + err: errors.New("error saving acme certificate: force"), + } + }, + "ok": func(t *testing.T) test { + cert := &acme.Certificate{ + AccountID: "accounttID", + OrderID: "orderID", + Leaf: leaf, + Intermediates: []*x509.Certificate{inter, root}, + } + var ( + id string + idPtr = &id + ) + + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + assert.Equals(t, old, nil) + + dbc := new(dbCert) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.FatalError(t, err) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.ID, cert.ID) + assert.Equals(t, dbc.AccountID, cert.AccountID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created)) + return nil, true, nil + }, + }, + _id: idPtr, + cert: cert, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateCertificate(context.Background(), tc.cert); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.ID, *tc._id) + } + } + }) + } +} + +func TestDB_GetCertificate(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + certID := "certID" + type test struct { + db nosql.DB + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + return nil, nosqldb.ErrNotFound + }, + }, + err: errors.New("certificate certID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + return nil, errors.Errorf("force") + }, + }, + err: errors.New("error loading certificate certID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + return []byte("foobar"), nil + }, + }, + err: errors.New("error unmarshaling certificate certID"), + } + }, + "fail/parseBundle-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + cert := dbCert{ + ID: certID, + AccountID: "accountID", + OrderID: "orderID", + Leaf: pem.EncodeToMemory(&pem.Block{ + Type: "Public Key", + Bytes: leaf.Raw, + }), + Created: clock.Now(), + } + b, err := json.Marshal(cert) + assert.FatalError(t, err) + + return b, nil + }, + }, + err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + cert := dbCert{ + ID: certID, + AccountID: "accountID", + OrderID: "orderID", + Leaf: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), + Intermediates: append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: inter.Raw, + }), pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: root.Raw, + })...), + Created: clock.Now(), + } + b, err := json.Marshal(cert) + assert.FatalError(t, err) + + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + cert, err := db.GetCertificate(context.Background(), certID) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, cert.ID, certID) + assert.Equals(t, cert.AccountID, "accountID") + assert.Equals(t, cert.OrderID, "orderID") + assert.Equals(t, cert.Leaf, leaf) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) + } + } + }) + } +} + +func Test_parseBundle(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + var certs []byte + for _, cert := range []*x509.Certificate{leaf, inter, root} { + certs = append(certs, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })...) + } + + type test struct { + b []byte + err error + } + var tests = map[string]test{ + "fail/bad-type-error": { + b: pem.EncodeToMemory(&pem.Block{ + Type: "Public Key", + Bytes: leaf.Raw, + }), + err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"), + }, + "fail/bad-pem-error": { + b: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: []byte("foo"), + }), + err: errors.Errorf("error parsing x509 certificate"), + }, + "fail/unexpected-data": { + b: append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), []byte("foo")...), + err: errors.Errorf("error decoding PEM: unexpected data"), + }, + "ok": { + b: certs, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ret, err := parseBundle(tc.b) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root}) + } + } + }) + } +} From 206909b12e86ae189d84833506c7a4c5696d4ac6 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 18 Mar 2021 23:08:13 -0700 Subject: [PATCH 22/47] [acme db interface] unit tests for challenge nosql db --- acme/authorization.go | 4 +- acme/challenge.go | 26 +- acme/db/nosql/authz.go | 10 +- acme/db/nosql/certificate.go | 6 +- acme/db/nosql/certificate_test.go | 40 ++- acme/db/nosql/challenge.go | 50 ++-- acme/db/nosql/challenge_test.go | 477 ++++++++++++++++++++++++++++++ acme/db/nosql/nonce.go | 62 ++-- acme/db/nosql/nonce_test.go | 209 +++++++++++++ 9 files changed, 788 insertions(+), 96 deletions(-) create mode 100644 acme/db/nosql/challenge_test.go create mode 100644 acme/db/nosql/nonce_test.go diff --git a/acme/authorization.go b/acme/authorization.go index df4ac229..cf68cba3 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -10,7 +10,7 @@ import ( type Authorization struct { Identifier Identifier `json:"identifier"` Status Status `json:"status"` - Expires time.Time `json:"expires"` + ExpiresAt time.Time `json:"expires"` Challenges []*Challenge `json:"challenges"` Wildcard bool `json:"wildcard"` ID string `json:"-"` @@ -39,7 +39,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusPending: // check expiry - if now.After(az.Expires) { + if now.After(az.ExpiresAt) { az.Status = StatusInvalid break } diff --git a/acme/challenge.go b/acme/challenge.go index 2c6f5fb1..a3514d15 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -22,16 +22,16 @@ import ( // Challenge represents an ACME response Challenge type. type Challenge struct { - Type string `json:"type"` - Status Status `json:"status"` - Token string `json:"token"` - Validated string `json:"validated,omitempty"` - URL string `json:"url"` - Error *Error `json:"error,omitempty"` - ID string `json:"-"` - AuthzID string `json:"-"` - AccountID string `json:"-"` - Value string `json:"-"` + Type string `json:"type"` + Status Status `json:"status"` + Token string `json:"token"` + ValidatedAt string `json:"validated,omitempty"` + URL string `json:"url"` + Error *Error `json:"error,omitempty"` + ID string `json:"-"` + AuthzID string `json:"-"` + AccountID string `json:"-"` + Value string `json:"-"` } // ToLog enables response logging. @@ -97,7 +97,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb // Update and store the challenge. ch.Status = StatusValid ch.Error = nil - ch.Validated = clock.Now().Format(time.RFC3339) + ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") @@ -175,7 +175,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON ch.Status = StatusValid ch.Error = nil - ch.Validated = clock.Now().Format(time.RFC3339) + ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge") @@ -231,7 +231,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK // Update and store the challenge. ch.Status = StatusValid ch.Error = nil - ch.Validated = clock.Now().UTC().Format(time.RFC3339) + ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 0992509d..a5d422a7 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -18,10 +18,10 @@ type dbAuthz struct { AccountID string `json:"accountID"` Identifier acme.Identifier `json:"identifier"` Status acme.Status `json:"status"` - Expires time.Time `json:"expires"` + ExpiresAt time.Time `json:"expiresAt"` Challenges []string `json:"challenges"` Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` + CreatedAt time.Time `json:"createdAt"` Error *acme.Error `json:"error"` } @@ -66,7 +66,7 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat Status: dbaz.Status, Challenges: chs, Wildcard: dbaz.Wildcard, - Expires: dbaz.Expires, + ExpiresAt: dbaz.ExpiresAt, ID: dbaz.ID, }, nil } @@ -90,8 +90,8 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e ID: az.ID, AccountID: az.AccountID, Status: acme.StatusPending, - Created: now, - Expires: now.Add(defaultExpiryDuration), + CreatedAt: now, + ExpiresAt: now.Add(defaultExpiryDuration), Identifier: az.Identifier, Challenges: chIDs, Wildcard: az.Wildcard, diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go index bd46df4d..d3e15833 100644 --- a/acme/db/nosql/certificate.go +++ b/acme/db/nosql/certificate.go @@ -14,7 +14,7 @@ import ( type dbCert struct { ID string `json:"id"` - Created time.Time `json:"created"` + CreatedAt time.Time `json:"createdAt"` AccountID string `json:"accountID"` OrderID string `json:"orderID"` Leaf []byte `json:"leaf"` @@ -47,7 +47,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err OrderID: cert.OrderID, Leaf: leaf, Intermediates: intermediates, - Created: time.Now().UTC(), + CreatedAt: time.Now().UTC(), } return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) } @@ -57,7 +57,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) { b, err := db.db.Get(certTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "certificate %s not found", id) + return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading certificate %s", id) } diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go index 83f07f5e..4ec4589e 100644 --- a/acme/db/nosql/certificate_test.go +++ b/acme/db/nosql/certificate_test.go @@ -34,7 +34,7 @@ func TestDB_CreateCertificate(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { cert := &acme.Certificate{ - AccountID: "accounttID", + AccountID: "accountID", OrderID: "orderID", Leaf: leaf, Intermediates: []*x509.Certificate{inter, root}, @@ -48,12 +48,11 @@ func TestDB_CreateCertificate(t *testing.T) { dbc := new(dbCert) assert.FatalError(t, json.Unmarshal(nu, dbc)) - assert.FatalError(t, err) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.ID, cert.ID) assert.Equals(t, dbc.AccountID, cert.AccountID) - assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created)) - assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, false, errors.New("force") }, }, @@ -63,7 +62,7 @@ func TestDB_CreateCertificate(t *testing.T) { }, "ok": func(t *testing.T) test { cert := &acme.Certificate{ - AccountID: "accounttID", + AccountID: "accountID", OrderID: "orderID", Leaf: leaf, Intermediates: []*x509.Certificate{inter, root}, @@ -83,12 +82,11 @@ func TestDB_CreateCertificate(t *testing.T) { dbc := new(dbCert) assert.FatalError(t, json.Unmarshal(nu, dbc)) - assert.FatalError(t, err) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.ID, cert.ID) assert.Equals(t, dbc.AccountID, cert.AccountID) - assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created)) - assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, true, nil }, }, @@ -124,8 +122,9 @@ func TestDB_GetCertificate(t *testing.T) { certID := "certID" type test struct { - db nosql.DB - err error + db nosql.DB + err error + acmeErr *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { @@ -138,7 +137,7 @@ func TestDB_GetCertificate(t *testing.T) { return nil, nosqldb.ErrNotFound }, }, - err: errors.New("certificate certID not found"), + acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { @@ -182,7 +181,7 @@ func TestDB_GetCertificate(t *testing.T) { Type: "Public Key", Bytes: leaf.Raw, }), - Created: clock.Now(), + CreatedAt: clock.Now(), } b, err := json.Marshal(cert) assert.FatalError(t, err) @@ -215,7 +214,7 @@ func TestDB_GetCertificate(t *testing.T) { Type: "CERTIFICATE", Bytes: root.Raw, })...), - Created: clock.Now(), + CreatedAt: clock.Now(), } b, err := json.Marshal(cert) assert.FatalError(t, err) @@ -232,8 +231,19 @@ func TestDB_GetCertificate(t *testing.T) { db := DB{db: tc.db} cert, err := db.GetCertificate(context.Background(), certID) if err != nil { - if assert.NotNil(t, tc.err) { - assert.HasPrefix(t, err.Error(), tc.err.Error()) + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } } } else { if assert.Nil(t, tc.err) { diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index 48340cf4..afcb4600 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -10,18 +10,17 @@ import ( "github.com/smallstep/nosql" ) -// dbChallenge is the base Challenge type that others build from. type dbChallenge struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - AuthzID string `json:"authzID"` - Type string `json:"type"` - Status acme.Status `json:"status"` - Token string `json:"token"` - Value string `json:"value"` - Validated string `json:"validated"` - Created time.Time `json:"created"` - Error *acme.Error `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + AuthzID string `json:"authzID"` + Type string `json:"type"` + Status acme.Status `json:"status"` + Token string `json:"token"` + Value string `json:"value"` + ValidatedAt string `json:"validatedAt"` + CreatedAt time.Time `json:"createdAt"` + Error *acme.Error `json:"error"` } func (dbc *dbChallenge) clone() *dbChallenge { @@ -32,9 +31,9 @@ func (dbc *dbChallenge) clone() *dbChallenge { func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { data, err := db.db.Get(challengeTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "challenge %s not found", id) + return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id) } else if err != nil { - return nil, errors.Wrapf(err, "error loading challenge %s", id) + return nil, errors.Wrapf(err, "error loading acme challenge %s", id) } dbch := new(dbChallenge) @@ -60,7 +59,7 @@ func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error { Value: ch.Value, Status: acme.StatusPending, Token: ch.Token, - Created: clock.Now(), + CreatedAt: clock.Now(), Type: ch.Type, } @@ -76,22 +75,21 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall } ch := &acme.Challenge{ - Type: dbch.Type, - Status: dbch.Status, - Token: dbch.Token, - ID: dbch.ID, - AuthzID: dbch.AuthzID, - Error: dbch.Error, - Validated: dbch.Validated, + ID: dbch.ID, + AccountID: dbch.AccountID, + AuthzID: dbch.AuthzID, + Type: dbch.Type, + Value: dbch.Value, + Status: dbch.Status, + Token: dbch.Token, + Error: dbch.Error, + ValidatedAt: dbch.ValidatedAt, } return ch, nil } // UpdateChallenge updates an ACME challenge type in the database. func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error { - if len(ch.ID) == 0 { - return errors.New("id cannot be empty") - } old, err := db.getDBChallenge(ctx, ch.ID) if err != nil { return err @@ -99,10 +97,10 @@ func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error { nu := old.clone() - // These should be the only values chaning in an Update request. + // These should be the only values changing in an Update request. nu.Status = ch.Status nu.Error = ch.Error - nu.Validated = ch.Validated + nu.ValidatedAt = ch.ValidatedAt return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) } diff --git a/acme/db/nosql/challenge_test.go b/acme/db/nosql/challenge_test.go new file mode 100644 index 00000000..34af74ce --- /dev/null +++ b/acme/db/nosql/challenge_test.go @@ -0,0 +1,477 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBChallenge(t *testing.T) { + chID := "chID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbc *dbChallenge + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading acme challenge chID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling dbChallenge"), + } + }, + "ok": func(t *testing.T) test { + dbc := &dbChallenge{ + ID: chID, + AccountID: "accountID", + AuthzID: "authzID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + CreatedAt: clock.Now(), + ValidatedAt: "foobar", + Error: acme.NewErrorISE("force"), + } + b, err := json.Marshal(dbc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + }, + dbc: dbc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if ch, err := db.getDBChallenge(context.Background(), chID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err, tc.acmeErr.Err) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.ID, tc.dbc.ID) + assert.Equals(t, ch.AccountID, tc.dbc.AccountID) + assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID) + assert.Equals(t, ch.Type, tc.dbc.Type) + assert.Equals(t, ch.Status, tc.dbc.Status) + assert.Equals(t, ch.Token, tc.dbc.Token) + assert.Equals(t, ch.Value, tc.dbc.Value) + assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) + assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) + } + } + }) + } +} + +func TestDB_CreateChallenge(t *testing.T) { + type test struct { + db nosql.DB + ch *acme.Challenge + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + ch := &acme.Challenge{ + AccountID: "accountID", + AuthzID: "authzID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), ch.ID) + assert.Equals(t, old, nil) + + dbc := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.AccountID, ch.AccountID) + assert.Equals(t, dbc.AuthzID, ch.AuthzID) + assert.Equals(t, dbc.Type, ch.Type) + assert.Equals(t, dbc.Status, ch.Status) + assert.Equals(t, dbc.Token, ch.Token) + assert.Equals(t, dbc.Value, ch.Value) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + ch: ch, + err: errors.New("error saving acme challenge: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ch = &acme.Challenge{ + AccountID: "accountID", + AuthzID: "authzID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + } + ) + + return test{ + ch: ch, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), ch.ID) + assert.Equals(t, old, nil) + + dbc := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.AccountID, ch.AccountID) + assert.Equals(t, dbc.AuthzID, ch.AuthzID) + assert.Equals(t, dbc.Type, ch.Type) + assert.Equals(t, dbc.Status, ch.Status) + assert.Equals(t, dbc.Token, ch.Token) + assert.Equals(t, dbc.Value, ch.Value) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + return nil, true, nil + }, + }, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateChallenge(context.Background(), tc.ch); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.ID, *tc._id) + } + } + }) + } +} + +func TestDB_GetChallenge(t *testing.T) { + chID := "chID" + azID := "azID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbc *dbChallenge + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading acme challenge chID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"), + } + }, + "ok": func(t *testing.T) test { + dbc := &dbChallenge{ + ID: chID, + AccountID: "accountID", + AuthzID: azID, + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + CreatedAt: clock.Now(), + ValidatedAt: "foobar", + Error: acme.NewErrorISE("force"), + } + b, err := json.Marshal(dbc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + }, + dbc: dbc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.ID, tc.dbc.ID) + assert.Equals(t, ch.AccountID, tc.dbc.AccountID) + assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID) + assert.Equals(t, ch.Type, tc.dbc.Type) + assert.Equals(t, ch.Status, tc.dbc.Status) + assert.Equals(t, ch.Token, tc.dbc.Token) + assert.Equals(t, ch.Value, tc.dbc.Value) + assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) + assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) + } + } + }) + } +} + +func TestDB_UpdateChallenge(t *testing.T) { + chID := "chID" + dbc := &dbChallenge{ + ID: chID, + AccountID: "accountID", + AuthzID: "azID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + CreatedAt: clock.Now(), + } + b, err := json.Marshal(dbc) + assert.FatalError(t, err) + type test struct { + db nosql.DB + ch *acme.Challenge + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ch: &acme.Challenge{ + ID: chID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading acme challenge chID: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + updCh := &acme.Challenge{ + ID: chID, + Status: acme.StatusValid, + ValidatedAt: "foobar", + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + ch: updCh, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, old, b) + + dbOld := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbc, dbOld) + + dbNew := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbc.ID) + assert.Equals(t, dbNew.AccountID, dbc.AccountID) + assert.Equals(t, dbNew.AuthzID, dbc.AuthzID) + assert.Equals(t, dbNew.Type, dbc.Type) + assert.Equals(t, dbNew.Status, updCh.Status) + assert.Equals(t, dbNew.Token, dbc.Token) + assert.Equals(t, dbNew.Value, dbc.Value) + assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error()) + assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt) + assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme challenge: force"), + } + }, + "ok": func(t *testing.T) test { + updCh := &acme.Challenge{ + ID: dbc.ID, + AccountID: dbc.AccountID, + AuthzID: dbc.AuthzID, + Type: dbc.Type, + Token: dbc.Token, + Value: dbc.Value, + Status: acme.StatusValid, + ValidatedAt: "foobar", + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + ch: updCh, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, old, b) + + dbOld := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbc, dbOld) + + dbNew := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbc.ID) + assert.Equals(t, dbNew.AccountID, dbc.AccountID) + assert.Equals(t, dbNew.AuthzID, dbc.AuthzID) + assert.Equals(t, dbNew.Type, dbc.Type) + assert.Equals(t, dbNew.Token, dbc.Token) + assert.Equals(t, dbNew.Value, dbc.Value) + assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt) + assert.Equals(t, dbNew.Status, acme.StatusValid) + assert.Equals(t, dbNew.ValidatedAt, "foobar") + assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.ID, dbc.ID) + assert.Equals(t, tc.ch.AccountID, dbc.AccountID) + assert.Equals(t, tc.ch.AuthzID, dbc.AuthzID) + assert.Equals(t, tc.ch.Type, dbc.Type) + assert.Equals(t, tc.ch.Token, dbc.Token) + assert.Equals(t, tc.ch.Value, dbc.Value) + assert.Equals(t, tc.ch.ValidatedAt, "foobar") + assert.Equals(t, tc.ch.Status, acme.StatusValid) + assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + } + } + }) + } +} diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index 76f742b2..4f587ae0 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -8,14 +8,19 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" - nosqlDB "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" + "github.com/smallstep/nosql" ) // dbNonce contains nonce metadata used in the ACME protocol. type dbNonce struct { - ID string - Created time.Time + ID string + CreatedAt time.Time + DeletedAt time.Time +} + +func (dbn *dbNonce) clone() *dbNonce { + u := *dbn + return &u } // CreateNonce creates, stores, and returns an ACME replay-nonce. @@ -28,14 +33,10 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { id := base64.RawURLEncoding.EncodeToString([]byte(_id)) n := &dbNonce{ - ID: id, - Created: clock.Now(), - } - b, err := json.Marshal(n) - if err != nil { - return "", errors.Wrap(err, "error marshaling nonce") + ID: id, + CreatedAt: clock.Now(), } - if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil { + if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { return "", err } return acme.Nonce(id), nil @@ -44,27 +45,24 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { // DeleteNonce verifies that the nonce is valid (by checking if it exists), // and if so, consumes the nonce resource by deleting it from the database. func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error { - err := db.db.Update(&database.Tx{ - Operations: []*database.TxEntry{ - { - Bucket: nonceTable, - Key: []byte(nonce), - Cmd: database.Get, - }, - { - Bucket: nonceTable, - Key: []byte(nonce), - Cmd: database.Delete, - }, - }, - }) + id := string(nonce) + b, err := db.db.Get(nonceTable, []byte(nonce)) + if nosql.IsErrNotFound(err) { + return errors.Wrapf(err, "nonce %s not found", id) + } else if err != nil { + return errors.Wrapf(err, "error loading nonce %s", id) + } - switch { - case nosqlDB.IsErrNotFound(err): - return errors.New("not found") - case err != nil: - return errors.Wrapf(err, "error deleting nonce %s", nonce) - default: - return nil + dbn := new(dbNonce) + if err := json.Unmarshal(b, dbn); err != nil { + return errors.Wrapf(err, "error unmarshaling nonce %s", string(nonce)) + } + if !dbn.DeletedAt.IsZero() { + return acme.NewError(acme.ErrorBadNonceType, "nonce %s already deleted", id) } + + nu := dbn.clone() + nu.DeletedAt = clock.Now() + + return db.save(ctx, id, nu, dbn, "nonce", nonceTable) } diff --git a/acme/db/nosql/nonce_test.go b/acme/db/nosql/nonce_test.go new file mode 100644 index 00000000..1159ec00 --- /dev/null +++ b/acme/db/nosql/nonce_test.go @@ -0,0 +1,209 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_CreateNonce(t *testing.T) { + type test struct { + db nosql.DB + nonce *acme.Nonce + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + + dbn := new(dbNonce) + assert.FatalError(t, json.Unmarshal(nu, dbn)) + assert.Equals(t, dbn.ID, string(key)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme nonce: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ) + + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + + dbn := new(dbNonce) + assert.FatalError(t, json.Unmarshal(nu, dbn)) + assert.Equals(t, dbn.ID, string(key)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt)) + return nil, true, nil + }, + }, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if n, err := db.CreateNonce(context.Background()); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, string(n), *tc._id) + } + } + }) + } +} + +func TestDB_DeleteNonce(t *testing.T) { + + nonceID := "nonceID" + type test struct { + db nosql.DB + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, string(key), nonceID) + + return nil, nosqldb.ErrNotFound + }, + }, + err: errors.New("nonce nonceID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, string(key), nonceID) + + return nil, errors.Errorf("force") + }, + }, + err: errors.New("error loading nonce nonceID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, string(key), nonceID) + + a := []string{"foo", "bar", "baz"} + b, err := json.Marshal(a) + assert.FatalError(t, err) + + return b, nil + }, + }, + err: errors.New("error unmarshaling nonce nonceID"), + } + }, + "fail/already-used": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, string(key), nonceID) + + nonce := dbNonce{ + ID: nonceID, + CreatedAt: clock.Now().Add(-5 * time.Minute), + DeletedAt: clock.Now(), + } + b, err := json.Marshal(nonce) + assert.FatalError(t, err) + + return b, nil + }, + }, + err: acme.NewError(acme.ErrorBadNonceType, "nonce already deleted"), + } + }, + "ok": func(t *testing.T) test { + nonce := dbNonce{ + ID: nonceID, + CreatedAt: clock.Now().Add(-5 * time.Minute), + } + b, err := json.Marshal(nonce) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, string(key), nonceID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, b) + + dbo := new(dbNonce) + assert.FatalError(t, json.Unmarshal(old, dbo)) + assert.Equals(t, dbo.ID, string(key)) + assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbo.CreatedAt)) + assert.True(t, clock.Now().Add(-4*time.Minute).After(dbo.CreatedAt)) + assert.True(t, dbo.DeletedAt.IsZero()) + + dbn := new(dbNonce) + assert.FatalError(t, json.Unmarshal(nu, dbn)) + assert.Equals(t, dbn.ID, string(key)) + assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbn.CreatedAt)) + assert.True(t, clock.Now().Add(-4*time.Minute).After(dbn.CreatedAt)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.DeletedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbn.DeletedAt)) + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} From f72b2ff2c2d39cb223a74c49be9e33c213badae5 Mon Sep 17 00:00:00 2001 From: max furman Date: Fri, 19 Mar 2021 14:37:45 -0700 Subject: [PATCH 23/47] [acme db interface] nosql authz unit tests --- acme/api/handler_test.go | 8 +- acme/api/order.go | 15 +- acme/api/order_test.go | 20 +- acme/authorization.go | 1 + acme/db/nosql/authz.go | 19 +- acme/db/nosql/authz_test.go | 620 ++++++++++++++++++++++++++++++++++++ acme/db/nosql/order.go | 4 +- acme/order.go | 6 +- ca/acmeClient_test.go | 8 +- 9 files changed, 667 insertions(+), 34 deletions(-) create mode 100644 acme/db/nosql/authz_test.go diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 70d2dc14..7fd8e110 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -121,9 +121,9 @@ func TestHandler_GetAuthorization(t *testing.T) { Type: "dns", Value: "example.com", }, - Status: "pending", - Expires: expiry, - Wildcard: false, + Status: "pending", + ExpiresAt: expiry, + Wildcard: false, Challenges: []*acme.Challenge{ { Type: "http-01", @@ -220,7 +220,7 @@ func TestHandler_GetAuthorization(t *testing.T) { return &acme.Authorization{ AccountID: "accID", Status: acme.StatusPending, - Expires: time.Now().Add(-1 * time.Hour), + ExpiresAt: time.Now().Add(-1 * time.Hour), }, nil }, MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error { diff --git a/acme/api/order.go b/acme/api/order.go index 9fe0eb26..379c2287 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -89,14 +89,24 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } + now := clock.Now() + expiry := now.Add(defaultOrderExpiry) // New order. - o := &acme.Order{Identifiers: nor.Identifiers} + o := &acme.Order{ + AccountID: acc.ID, + ProvisionerID: prov.GetID(), + Status: acme.StatusPending, + ExpiresAt: expiry, + Identifiers: nor.Identifiers, + } o.AuthorizationIDs = make([]string, len(o.Identifiers)) for i, identifier := range o.Identifiers { az := &acme.Authorization{ AccountID: acc.ID, Identifier: identifier, + ExpiresAt: expiry, + Status: acme.StatusPending, } if err := h.newAuthorization(ctx, az); err != nil { api.WriteError(w, err) @@ -105,14 +115,12 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { o.AuthorizationIDs[i] = az.ID } - now := clock.Now() if o.NotBefore.IsZero() { o.NotBefore = now } if o.NotAfter.IsZero() { o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration()) } - o.Expires = now.Add(defaultOrderExpiry) if err := h.db.CreateOrder(ctx, o); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) @@ -156,6 +164,7 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) Value: az.Identifier.Value, Type: typ, Token: az.Token, + Status: acme.StatusPending, } if err := h.db.CreateChallenge(ctx, ch); err != nil { return err diff --git a/acme/api/order_test.go b/acme/api/order_test.go index b6783e34..0bc3caab 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -166,9 +166,9 @@ func TestHandler_GetOrder(t *testing.T) { Value: "*.smallstep.com", }, }, - Expires: expiry, - Status: acme.StatusInvalid, - Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), + ExpiresAt: expiry, + Status: acme.StatusInvalid, + Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), AuthorizationURLs: []string{ "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", @@ -285,7 +285,7 @@ func TestHandler_GetOrder(t *testing.T) { return &acme.Order{ AccountID: "accountID", ProvisionerID: "acme/test@acme-provisioner.com", - Expires: clock.Now().Add(-time.Hour), + ExpiresAt: clock.Now().Add(-time.Hour), Status: acme.StatusReady, }, nil }, @@ -311,7 +311,7 @@ func TestHandler_GetOrder(t *testing.T) { ID: "orderID", AccountID: "accountID", ProvisionerID: "acme/test@acme-provisioner.com", - Expires: expiry, + ExpiresAt: expiry, Status: acme.StatusReady, AuthorizationIDs: []string{"foo", "bar", "baz"}, NotBefore: nbf, @@ -380,7 +380,7 @@ func TestHandler_NewOrder(t *testing.T) { naf := nbf.Add(17 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry, + ExpiresAt: expiry, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ @@ -607,8 +607,8 @@ func TestHandler_FinalizeOrder(t *testing.T) { Value: "*.smallstep.com", }, }, - Expires: naf, - Status: acme.StatusValid, + ExpiresAt: naf, + Status: acme.StatusValid, AuthorizationURLs: []string{ "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", @@ -788,7 +788,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { return &acme.Order{ AccountID: "accountID", ProvisionerID: "acme/test@acme-provisioner.com", - Expires: clock.Now().Add(-time.Hour), + ExpiresAt: clock.Now().Add(-time.Hour), Status: acme.StatusReady, }, nil }, @@ -815,7 +815,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { ID: "orderID", AccountID: "accountID", ProvisionerID: "acme/test@acme-provisioner.com", - Expires: naf, + ExpiresAt: naf, Status: acme.StatusValid, AuthorizationIDs: []string{"foo", "bar", "baz"}, NotBefore: nbf, diff --git a/acme/authorization.go b/acme/authorization.go index cf68cba3..62bc4637 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -13,6 +13,7 @@ type Authorization struct { ExpiresAt time.Time `json:"expires"` Challenges []*Challenge `json:"challenges"` Wildcard bool `json:"wildcard"` + Error *Error `json:"error,omitempty"` ID string `json:"-"` AccountID string `json:"-"` Token string `json:"-"` diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index a5d422a7..2ea1bb69 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -23,6 +23,7 @@ type dbAuthz struct { Wildcard bool `json:"wildcard"` CreatedAt time.Time `json:"createdAt"` Error *acme.Error `json:"error"` + Token string `json:"token"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -35,14 +36,14 @@ func (ba *dbAuthz) clone() *dbAuthz { func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) { data, err := db.db.Get(authzTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "authz %s not found", id) + return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading authz %s", id) } var dbaz dbAuthz if err = json.Unmarshal(data, &dbaz); err != nil { - return nil, errors.Wrap(err, "error unmarshaling authz type into dbAuthz") + return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id) } return &dbaz, nil } @@ -62,12 +63,15 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat } } return &acme.Authorization{ + ID: dbaz.ID, + AccountID: dbaz.AccountID, Identifier: dbaz.Identifier, Status: dbaz.Status, Challenges: chs, Wildcard: dbaz.Wildcard, ExpiresAt: dbaz.ExpiresAt, - ID: dbaz.ID, + Token: dbaz.Token, + Error: dbaz.Error, }, nil } @@ -89,11 +93,12 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e dbaz := &dbAuthz{ ID: az.ID, AccountID: az.AccountID, - Status: acme.StatusPending, + Status: az.Status, CreatedAt: now, - ExpiresAt: now.Add(defaultExpiryDuration), + ExpiresAt: az.ExpiresAt, Identifier: az.Identifier, Challenges: chIDs, + Token: az.Token, Wildcard: az.Wildcard, } @@ -102,9 +107,6 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e // UpdateAuthorization saves an updated ACME Authorization to the database. func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error { - if len(az.ID) == 0 { - return errors.New("id cannot be empty") - } old, err := db.getDBAuthz(ctx, az.ID) if err != nil { return err @@ -113,5 +115,6 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e nu := old.clone() nu.Status = az.Status + nu.Error = az.Error return db.save(ctx, old.ID, nu, old, "authz", authzTable) } diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go new file mode 100644 index 00000000..825c4648 --- /dev/null +++ b/acme/db/nosql/authz_test.go @@ -0,0 +1,620 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBAuthz(t *testing.T) { + azID := "azID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbaz *dbAuthz + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authz azID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling authz azID into dbAuthz"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + Challenges: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return b, nil + }, + }, + dbaz: dbaz, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbaz.ID, tc.dbaz.ID) + assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) + assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) + assert.Equals(t, dbaz.Status, tc.dbaz.Status) + assert.Equals(t, dbaz.Token, tc.dbaz.Token) + assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) + assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) + assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) + assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard) + } + } + }) + } +} + +func TestDB_GetAuthorization(t *testing.T) { + azID := "azID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbaz *dbAuthz + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authz azID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"), + } + }, + "fail/db.GetChallenge-error": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + Challenges: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + assert.Equals(t, string(key), azID) + return b, nil + case string(challengeTable): + assert.Equals(t, string(key), "foo") + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) + return nil, errors.New("force") + } + }, + }, + err: errors.New("error loading acme challenge foo: force"), + } + }, + "fail/db.GetChallenge-not-found": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + Challenges: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + assert.Equals(t, string(key), azID) + return b, nil + case string(challengeTable): + assert.Equals(t, string(key), "foo") + return nil, nosqldb.ErrNotFound + default: + assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + Challenges: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + chCount := 0 + fooChb, err := json.Marshal(&dbChallenge{ID: "foo"}) + assert.FatalError(t, err) + barChb, err := json.Marshal(&dbChallenge{ID: "bar"}) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + assert.Equals(t, string(key), azID) + return b, nil + case string(challengeTable): + if chCount == 0 { + chCount++ + assert.Equals(t, string(key), "foo") + return fooChb, nil + } + assert.Equals(t, string(key), "bar") + return barChb, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) + return nil, errors.New("force") + } + }, + }, + dbaz: dbaz, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if az, err := db.GetAuthorization(context.Background(), azID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, az.ID, tc.dbaz.ID) + assert.Equals(t, az.AccountID, tc.dbaz.AccountID) + assert.Equals(t, az.Identifier, tc.dbaz.Identifier) + assert.Equals(t, az.Status, tc.dbaz.Status) + assert.Equals(t, az.Token, tc.dbaz.Token) + assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) + assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) + assert.Equals(t, az.Challenges, []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }) + assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error()) + } + } + }) + } +} + +func TestDB_CreateAuthorization(t *testing.T) { + azID := "azID" + type test struct { + db nosql.DB + az *acme.Authorization + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + now := clock.Now() + az := &acme.Authorization{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }, + Wildcard: true, + Error: acme.NewErrorISE("force"), + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), az.ID) + assert.Equals(t, old, nil) + + dbaz := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbaz)) + assert.Equals(t, dbaz.ID, string(key)) + assert.Equals(t, dbaz.AccountID, az.AccountID) + assert.Equals(t, dbaz.Identifier, acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }) + assert.Equals(t, dbaz.Status, az.Status) + assert.Equals(t, dbaz.Token, az.Token) + assert.Equals(t, dbaz.Challenges, []string{"foo", "bar"}) + assert.Equals(t, dbaz.Wildcard, az.Wildcard) + assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) + assert.Nil(t, dbaz.Error) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + az: az, + err: errors.New("error saving acme authz: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + now = clock.Now() + az = &acme.Authorization{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }, + Wildcard: true, + Error: acme.NewErrorISE("force"), + } + ) + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), az.ID) + assert.Equals(t, old, nil) + + dbaz := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbaz)) + assert.Equals(t, dbaz.ID, string(key)) + assert.Equals(t, dbaz.AccountID, az.AccountID) + assert.Equals(t, dbaz.Identifier, acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }) + assert.Equals(t, dbaz.Status, az.Status) + assert.Equals(t, dbaz.Token, az.Token) + assert.Equals(t, dbaz.Challenges, []string{"foo", "bar"}) + assert.Equals(t, dbaz.Wildcard, az.Wildcard) + assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) + assert.Nil(t, dbaz.Error) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt)) + return nu, true, nil + }, + }, + az: az, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateAuthorization(context.Background(), tc.az); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.az.ID, *tc._id) + } + } + }) + } +} + +func TestDB_UpdateAuthorization(t *testing.T) { + azID := "azID" + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + type test struct { + db nosql.DB + az *acme.Authorization + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + az: &acme.Authorization{ + ID: azID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authz azID: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + updAz := &acme.Authorization{ + ID: azID, + Status: acme.StatusValid, + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + az: updAz, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, old, b) + + dbOld := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbaz, dbOld) + + dbNew := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbaz.ID) + assert.Equals(t, dbNew.AccountID, dbaz.AccountID) + assert.Equals(t, dbNew.Identifier, dbaz.Identifier) + assert.Equals(t, dbNew.Status, acme.StatusValid) + assert.Equals(t, dbNew.Token, dbaz.Token) + assert.Equals(t, dbNew.Challenges, dbaz.Challenges) + assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) + assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) + assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme authz: force"), + } + }, + "ok": func(t *testing.T) test { + updAz := &acme.Authorization{ + ID: azID, + AccountID: dbaz.AccountID, + Status: acme.StatusValid, + Identifier: dbaz.Identifier, + Challenges: []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }, + Token: dbaz.Token, + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + az: updAz, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, old, b) + + dbOld := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbaz, dbOld) + + dbNew := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbaz.ID) + assert.Equals(t, dbNew.AccountID, dbaz.AccountID) + assert.Equals(t, dbNew.Identifier, dbaz.Identifier) + assert.Equals(t, dbNew.Status, acme.StatusValid) + assert.Equals(t, dbNew.Token, dbaz.Token) + assert.Equals(t, dbNew.Challenges, dbaz.Challenges) + assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) + assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) + assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.az.ID, dbaz.ID) + assert.Equals(t, tc.az.AccountID, dbaz.AccountID) + assert.Equals(t, tc.az.Identifier, dbaz.Identifier) + assert.Equals(t, tc.az.Status, acme.StatusValid) + assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard) + assert.Equals(t, tc.az.Token, dbaz.Token) + assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt) + assert.Equals(t, tc.az.Challenges, []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }) + assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + } + } + }) + } +} diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 59afc41c..bc89442a 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -58,7 +58,7 @@ func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { o := &acme.Order{ Status: dbo.Status, - Expires: dbo.Expires, + ExpiresAt: dbo.Expires, Identifiers: dbo.Identifiers, NotBefore: dbo.NotBefore, NotAfter: dbo.NotAfter, @@ -86,7 +86,7 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { ProvisionerID: o.ProvisionerID, Created: now, Status: acme.StatusPending, - Expires: o.Expires, + Expires: o.ExpiresAt, Identifiers: o.Identifiers, NotBefore: o.NotBefore, NotAfter: o.NotBefore, diff --git a/acme/order.go b/acme/order.go index a2c89fe7..f62e3354 100644 --- a/acme/order.go +++ b/acme/order.go @@ -22,7 +22,7 @@ type Identifier struct { type Order struct { ID string `json:"id"` Status Status `json:"status"` - Expires time.Time `json:"expires,omitempty"` + ExpiresAt time.Time `json:"expires,omitempty"` Identifiers []Identifier `json:"identifiers"` NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` @@ -59,7 +59,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusReady: // Check expiry - if now.After(o.Expires) { + if now.After(o.ExpiresAt) { o.Status = StatusInvalid o.Error = NewError(ErrorMalformedType, "order has expired") break @@ -67,7 +67,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusPending: // Check expiry - if now.After(o.Expires) { + if now.After(o.ExpiresAt) { o.Status = StatusInvalid o.Error = NewError(ErrorMalformedType, "order has expired") break diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index 08d4b734..3fbd42c5 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -388,7 +388,7 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) ord := acme.Order{ Status: "valid", - Expires: time.Now(), // "soon" + ExpiresAt: time.Now(), // "soon" FinalizeURL: "finalize-url", } ac := &ACMEClient{ @@ -510,7 +510,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) ord := acme.Order{ Status: "valid", - Expires: time.Now(), // "soon" + ExpiresAt: time.Now(), // "soon" FinalizeURL: "finalize-url", } ac := &ACMEClient{ @@ -630,7 +630,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) az := acme.Authorization{ Status: "valid", - Expires: time.Now(), + ExpiresAt: time.Now(), Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, } ac := &ACMEClient{ @@ -988,7 +988,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) ord := acme.Order{ Status: "valid", - Expires: time.Now(), // "soon" + ExpiresAt: time.Now(), // "soon" FinalizeURL: "finalize-url", CertificateURL: "cert-url", } From ce13d09dcba7a81cb6a276ca37c83b53d25feec0 Mon Sep 17 00:00:00 2001 From: max furman Date: Fri, 19 Mar 2021 15:01:26 -0700 Subject: [PATCH 24/47] add `at` to time attributes in dbAccount --- acme/db/nosql/account.go | 24 ++++++++++++------------ acme/db/nosql/order.go | 4 +--- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index befeb54d..0e0a7c4b 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -13,12 +13,12 @@ import ( // dbAccount represents an ACME account. type dbAccount struct { - ID string `json:"id"` - Created time.Time `json:"created"` - Deactivated time.Time `json:"deactivated"` - Key *jose.JSONWebKey `json:"key"` - Contact []string `json:"contact,omitempty"` - Status acme.Status `json:"status"` + ID string `json:"id"` + CreatedAt time.Time `json:"createdAt"` + DeactivatedAt time.Time `json:"deactivatedAt"` + Key *jose.JSONWebKey `json:"key"` + Contact []string `json:"contact,omitempty"` + Status acme.Status `json:"status"` } func (dba *dbAccount) clone() *dbAccount { @@ -35,11 +35,11 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { } dba := &dbAccount{ - ID: acc.ID, - Key: acc.Key, - Contact: acc.Contact, - Status: acc.Status, - Created: clock.Now(), + ID: acc.ID, + Key: acc.Key, + Contact: acc.Contact, + Status: acc.Status, + CreatedAt: clock.Now(), } kid, err := acme.KeyToID(dba.Key) @@ -105,7 +105,7 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { // If the status has changed to 'deactivated', then set deactivatedAt timestamp. if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated { - nu.Deactivated = clock.Now() + nu.DeactivatedAt = clock.Now() } return db.save(ctx, old.ID, nu, old, "account", accountTable) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index bc89442a..862c32df 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -105,9 +105,6 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { type orderIDsByAccount struct{} -// addOrderID adds an order ID to a users index of in progress order IDs. -// This method will also cull any orders that are no longer in the `pending` -// state from the index before returning it. func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) { ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() @@ -157,6 +154,7 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st return pendOids, nil } +// GetOrdersByAccountID returns a list of order IDs owned by the account. func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { return db.updateAddOrderIDs(ctx, accID) } From 88e6f0034742ac1ba98bbd5bb48487fdcff35e81 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 22 Mar 2021 14:46:05 -0700 Subject: [PATCH 25/47] nosql account db unit tests --- acme/db/nosql/account.go | 108 +++-- acme/db/nosql/account_test.go | 752 ++++++++++++++++++++++++++++++++++ 2 files changed, 804 insertions(+), 56 deletions(-) create mode 100644 acme/db/nosql/account_test.go diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 0e0a7c4b..3115e8ab 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -26,6 +26,58 @@ func (dba *dbAccount) clone() *dbAccount { return &nu } +func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { + id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return "", acme.NewError(acme.ErrorMalformedType, "account with key-id %s not found", kid) + } + return "", errors.Wrapf(err, "error loading key-account index for key %s", kid) + } + return string(id), nil +} + +// getDBAccount retrieves and unmarshals dbAccount. +func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { + data, err := db.db.Get(accountTable, []byte(id)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "account %s not found", id) + } + return nil, errors.Wrapf(err, "error loading account %s", id) + } + + dbacc := new(dbAccount) + if err = json.Unmarshal(data, dbacc); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id) + } + return dbacc, nil +} + +// GetAccount retrieves an ACME account by ID. +func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { + dbacc, err := db.getDBAccount(ctx, id) + if err != nil { + return nil, err + } + + return &acme.Account{ + Status: dbacc.Status, + Contact: dbacc.Contact, + Key: dbacc.Key, + ID: dbacc.ID, + }, nil +} + +// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). +func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { + id, err := db.getAccountIDByKeyID(ctx, kid) + if err != nil { + return nil, err + } + return db.GetAccount(ctx, id) +} + // CreateAccount imlements the AcmeDB.CreateAccount interface. func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { var err error @@ -64,36 +116,8 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { } } -// GetAccount retrieves an ACME account by ID. -func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { - dbacc, err := db.getDBAccount(ctx, id) - if err != nil { - return nil, err - } - - return &acme.Account{ - Status: dbacc.Status, - Contact: dbacc.Contact, - Key: dbacc.Key, - ID: dbacc.ID, - }, nil -} - -// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). -func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { - id, err := db.getAccountIDByKeyID(ctx, kid) - if err != nil { - return nil, err - } - return db.GetAccount(ctx, id) -} - // UpdateAccount imlements the AcmeDB.UpdateAccount interface. func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { - if len(acc.ID) == 0 { - return errors.New("id cannot be empty") - } - old, err := db.getDBAccount(ctx, acc.ID) if err != nil { return err @@ -110,31 +134,3 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { return db.save(ctx, old.ID, nu, old, "account", accountTable) } - -func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { - id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) - if err != nil { - if nosqlDB.IsErrNotFound(err) { - return "", errors.Wrapf(err, "account with key id %s not found", kid) - } - return "", errors.Wrapf(err, "error loading key-account index") - } - return string(id), nil -} - -// getDBAccount retrieves and unmarshals dbAccount. -func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { - data, err := db.db.Get(accountTable, []byte(id)) - if err != nil { - if nosqlDB.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "account %s not found", id) - } - return nil, errors.Wrapf(err, "error loading account %s", id) - } - - dbacc := new(dbAccount) - if err = json.Unmarshal(data, dbacc); err != nil { - return nil, errors.Wrap(err, "error unmarshaling account") - } - return dbacc, nil -} diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go new file mode 100644 index 00000000..9f889e64 --- /dev/null +++ b/acme/db/nosql/account_test.go @@ -0,0 +1,752 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + "go.step.sm/crypto/jose" +) + +func TestDB_getDBAccount(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling account accID into dbAccount"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbacc.ID, tc.dbacc.ID) + assert.Equals(t, dbacc.Status, tc.dbacc.Status) + assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) + assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) + assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) + assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_getAccountIDByKeyID(t *testing.T) { + accID := "accID" + kid := "kid" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading key-account index for key kid: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return []byte(accID), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, retAccID, accID) + } + } + }) + } +} + +func TestDB_GetAccount(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + return b, nil + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if acc, err := db.GetAccount(context.Background(), accID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_GetAccountByKeyID(t *testing.T) { + accID := "accID" + kid := "kid" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.getAccountIDByKeyID-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(accountByKeyIDTable)) + assert.Equals(t, string(key), kid) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading key-account index for key kid: force"), + } + }, + "fail/db.getAccountIDByKeyID-forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(accountByKeyIDTable)) + assert.Equals(t, string(key), kid) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), + } + }, + "fail/db.GetAccount-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/db.GetAccount-forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return nil, nosqldb.ErrNotFound + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_CreateAccount(t *testing.T) { + type test struct { + db nosql.DB + acc *acme.Account + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/keyID-cmpAndSwap-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + + assert.Equals(t, nu, []byte(acc.ID)) + return nil, false, errors.New("force") + }, + }, + acc: acc, + err: errors.New("error storing keyID to accountID index: force"), + } + }, + "fail/keyID-cmpAndSwap-false": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + + assert.Equals(t, nu, []byte(acc.ID)) + return nil, false, nil + }, + }, + acc: acc, + err: errors.New("key-id to account-id index already exists"), + } + }, + "fail/account-save-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + return nu, true, nil + case string(accountTable): + assert.Equals(t, string(key), acc.ID) + assert.Equals(t, old, nil) + + dbacc := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbacc)) + assert.Equals(t, dbacc.ID, string(key)) + assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) + assert.True(t, dbacc.DeactivatedAt.IsZero()) + return nil, false, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + acc: acc, + err: errors.New("error saving acme account: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + id = string(key) + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + return nu, true, nil + case string(accountTable): + assert.Equals(t, string(key), acc.ID) + assert.Equals(t, old, nil) + + dbacc := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbacc)) + assert.Equals(t, dbacc.ID, string(key)) + assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) + assert.True(t, dbacc.DeactivatedAt.IsZero()) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + acc: acc, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateAccount(context.Background(), tc.acc); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, *tc._id) + } + } + }) + } +} + +func TestDB_UpdateAccount(t *testing.T) { + accID := "accID" + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + type test struct { + db nosql.DB + acc *acme.Account + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + acc: &acme.Account{ + ID: accID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/already-deactivated": func(t *testing.T) test { + clone := dbacc.clone() + clone.Status = acme.StatusDeactivated + clone.DeactivatedAt = now + dbaccb, err := json.Marshal(clone) + assert.FatalError(t, err) + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return dbaccb, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, clone.ID) + assert.Equals(t, dbNew.Status, clone.Status) + assert.Equals(t, dbNew.Contact, clone.Contact) + assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt) + assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme account: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbacc.ID) + assert.Equals(t, dbNew.Status, acc.Status) + assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) + assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme account: force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbacc.ID) + assert.Equals(t, dbNew.Status, acc.Status) + assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) + assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateAccount(context.Background(), tc.acc); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, dbacc.ID) + assert.Equals(t, tc.acc.Status, dbacc.Status) + assert.Equals(t, tc.acc.Contact, dbacc.Contact) + assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID) + } + } + }) + } +} From 7f9ffbd514185aa806cdbc3e248f82d4452abe6d Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 22 Mar 2021 22:28:56 -0700 Subject: [PATCH 26/47] adding more acme nosql unit tests --- acme/db/nosql/challenge_test.go | 2 +- acme/db/nosql/nosql_test.go | 126 ++++++++ acme/db/nosql/order.go | 90 +++--- acme/db/nosql/order_test.go | 557 ++++++++++++++++++++++++++++++++ 4 files changed, 730 insertions(+), 45 deletions(-) create mode 100644 acme/db/nosql/nosql_test.go create mode 100644 acme/db/nosql/order_test.go diff --git a/acme/db/nosql/challenge_test.go b/acme/db/nosql/challenge_test.go index 34af74ce..314fc5f7 100644 --- a/acme/db/nosql/challenge_test.go +++ b/acme/db/nosql/challenge_test.go @@ -101,7 +101,7 @@ func TestDB_getDBChallenge(t *testing.T) { assert.Equals(t, k.Type, tc.acmeErr.Type) assert.Equals(t, k.Detail, tc.acmeErr.Detail) assert.Equals(t, k.Status, tc.acmeErr.Status) - assert.Equals(t, k.Err, tc.acmeErr.Err) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, k.Detail, tc.acmeErr.Detail) } default: diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go new file mode 100644 index 00000000..b7a91a2f --- /dev/null +++ b/acme/db/nosql/nosql_test.go @@ -0,0 +1,126 @@ +package nosql + +import ( + "context" + "testing" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" +) + +func TestNew(t *testing.T) { + type test struct { + db nosql.DB + err error + } + var tests = map[string]test{ + "fail/db.CreateTable-error": test{ + db: &db.MockNoSQLDB{ + MCreateTable: func(bucket []byte) error { + assert.Equals(t, string(bucket), string(accountTable)) + return errors.New("force") + }, + }, + err: errors.Errorf("error creating table %s: force", string(accountTable)), + }, + "ok": test{ + db: &db.MockNoSQLDB{ + MCreateTable: func(bucket []byte) error { + return nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + if _, err := New(tc.db); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +type errorThrower string + +func (et errorThrower) MarshalJSON() ([]byte, error) { + return nil, errors.New("force") +} + +func TestDB_save(t *testing.T) { + type test struct { + db nosql.DB + nu interface{} + old interface{} + err error + } + var tests = map[string]test{ + "fail/error-marshaling-new": test{ + nu: errorThrower("foo"), + err: errors.New("error marshaling acme type: challenge"), + }, + "fail/error-marshaling-old": test{ + nu: "new", + old: errorThrower("foo"), + err: errors.New("error marshaling acme type: challenge"), + }, + "fail/db.CmpAndSwap-error": test{ + nu: "new", + old: "old", + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, string(old), "\"old\"") + assert.Equals(t, string(nu), "\"new\"") + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme challenge: force"), + }, + "fail/db.CmpAndSwap-false-marshaling-old": test{ + nu: "new", + old: "old", + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, string(old), "\"old\"") + assert.Equals(t, string(nu), "\"new\"") + return nil, false, nil + }, + }, + err: errors.New("error saving acme challenge; changed since last read"), + }, + "ok": test{ + nu: "new", + old: "old", + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, string(old), "\"old\"") + assert.Equals(t, string(nu), "\"new\"") + return nu, true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + db := &DB{db: tc.db} + if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 862c32df..a64316a6 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -15,18 +15,18 @@ import ( var ordersByAccountMux sync.Mutex type dbOrder struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - ProvisionerID string `json:"provisionerID"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires,omitempty"` - Status acme.Status `json:"status"` - Identifiers []acme.Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore,omitempty"` - NotAfter time.Time `json:"notAfter,omitempty"` - Error *acme.Error `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - CertificateID string `json:"certificate,omitempty"` + ID string `json:"id"` + AccountID string `json:"accountID"` + ProvisionerID string `json:"provisionerID"` + CreatedAt time.Time `json:"createdAt"` + ExpiresAt time.Time `json:"expiresAt,omitempty"` + Status acme.Status `json:"status"` + Identifiers []acme.Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + Error *acme.Error `json:"error,omitempty"` + AuthorizationIDs []string `json:"authorizationIDs"` + CertificateID string `json:"certificate,omitempty"` } func (a *dbOrder) clone() *dbOrder { @@ -38,13 +38,13 @@ func (a *dbOrder) clone() *dbOrder { func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) { b, err := db.db.Get(orderTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id) + return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading order %s", id) } o := new(dbOrder) if err := json.Unmarshal(b, &o); err != nil { - return nil, errors.Wrap(err, "error unmarshaling order") + return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id) } return o, nil } @@ -57,15 +57,17 @@ func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { } o := &acme.Order{ + ID: dbo.ID, + AccountID: dbo.AccountID, + ProvisionerID: dbo.ProvisionerID, + CertificateID: dbo.CertificateID, Status: dbo.Status, - ExpiresAt: dbo.Expires, + ExpiresAt: dbo.ExpiresAt, Identifiers: dbo.Identifiers, NotBefore: dbo.NotBefore, NotAfter: dbo.NotAfter, - AuthorizationIDs: dbo.Authorizations, - ID: dbo.ID, - ProvisionerID: dbo.ProvisionerID, - CertificateID: dbo.CertificateID, + AuthorizationIDs: dbo.AuthorizationIDs, + Error: dbo.Error, } return o, nil @@ -81,16 +83,16 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { now := clock.Now() dbo := &dbOrder{ - ID: o.ID, - AccountID: o.AccountID, - ProvisionerID: o.ProvisionerID, - Created: now, - Status: acme.StatusPending, - Expires: o.ExpiresAt, - Identifiers: o.Identifiers, - NotBefore: o.NotBefore, - NotAfter: o.NotBefore, - Authorizations: o.AuthorizationIDs, + ID: o.ID, + AccountID: o.AccountID, + ProvisionerID: o.ProvisionerID, + Status: o.Status, + CreatedAt: now, + ExpiresAt: o.ExpiresAt, + Identifiers: o.Identifiers, + NotBefore: o.NotBefore, + NotAfter: o.NotBefore, + AuthorizationIDs: o.AuthorizationIDs, } if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { return err @@ -103,6 +105,21 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { return nil } +// UpdateOrder saves an updated ACME Order to the database. +func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { + old, err := db.getDBOrder(ctx, o.ID) + if err != nil { + return err + } + + nu := old.clone() + + nu.Status = o.Status + nu.Error = o.Error + nu.CertificateID = o.CertificateID + return db.save(ctx, old.ID, nu, old, "order", orderTable) +} + type orderIDsByAccount struct{} func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) { @@ -158,18 +175,3 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { return db.updateAddOrderIDs(ctx, accID) } - -// UpdateOrder saves an updated ACME Order to the database. -func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { - old, err := db.getDBOrder(ctx, o.ID) - if err != nil { - return err - } - - nu := old.clone() - - nu.Status = o.Status - nu.Error = o.Error - nu.CertificateID = o.CertificateID - return db.save(ctx, old.ID, nu, old, "order", orderTable) -} diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go new file mode 100644 index 00000000..8ce7ac79 --- /dev/null +++ b/acme/db/nosql/order_test.go @@ -0,0 +1,557 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBOrder(t *testing.T) { + orderID := "orderID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbo *dbOrder + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading order orderID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling order orderID into dbOrder"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbo := &dbOrder{ + ID: orderID, + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + CreatedAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + b, err := json.Marshal(dbo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return b, nil + }, + }, + dbo: dbo, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbo.ID, tc.dbo.ID) + assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID) + assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID) + assert.Equals(t, dbo.Status, tc.dbo.Status) + assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt) + assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt) + assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore) + assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter) + assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers) + assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs) + assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error()) + } + } + }) + } +} + +func TestDB_GetOrder(t *testing.T) { + orderID := "orderID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbo *dbOrder + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading order orderID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbo := &dbOrder{ + ID: orderID, + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + CreatedAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + b, err := json.Marshal(dbo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + return b, nil + }, + }, + dbo: dbo, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if o, err := db.GetOrder(context.Background(), orderID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, o.ID, tc.dbo.ID) + assert.Equals(t, o.AccountID, tc.dbo.AccountID) + assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID) + assert.Equals(t, o.CertificateID, tc.dbo.CertificateID) + assert.Equals(t, o.Status, tc.dbo.Status) + assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt) + assert.Equals(t, o.NotBefore, tc.dbo.NotBefore) + assert.Equals(t, o.NotAfter, tc.dbo.NotAfter) + assert.Equals(t, o.Identifiers, tc.dbo.Identifiers) + assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs) + assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error()) + } + } + }) + } +} + +func TestDB_UpdateOrder(t *testing.T) { + orderID := "orderID" + now := clock.Now() + dbo := &dbOrder{ + ID: orderID, + AccountID: "accID", + ProvisionerID: "provID", + Status: acme.StatusPending, + ExpiresAt: now, + CreatedAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + b, err := json.Marshal(dbo) + assert.FatalError(t, err) + type test struct { + db nosql.DB + o *acme.Order + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + o: &acme.Order{ + ID: orderID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading order orderID: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + o := &acme.Order{ + ID: orderID, + Status: acme.StatusValid, + CertificateID: "certID", + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + return test{ + o: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, old, b) + + dbNew := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbo.ID) + assert.Equals(t, dbNew.AccountID, dbo.AccountID) + assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID) + assert.Equals(t, dbNew.CertificateID, o.CertificateID) + assert.Equals(t, dbNew.Status, o.Status) + assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt) + assert.Equals(t, dbNew.NotBefore, dbo.NotBefore) + assert.Equals(t, dbNew.NotAfter, dbo.NotAfter) + assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs) + assert.Equals(t, dbNew.Identifiers, dbo.Identifiers) + assert.Equals(t, dbNew.Error.Error(), o.Error.Error()) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme order: force"), + } + }, + "ok": func(t *testing.T) test { + o := &acme.Order{ + ID: orderID, + Status: acme.StatusValid, + CertificateID: "certID", + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + return test{ + o: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, old, b) + + dbNew := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbo.ID) + assert.Equals(t, dbNew.AccountID, dbo.AccountID) + assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID) + assert.Equals(t, dbNew.CertificateID, o.CertificateID) + assert.Equals(t, dbNew.Status, o.Status) + assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt) + assert.Equals(t, dbNew.NotBefore, dbo.NotBefore) + assert.Equals(t, dbNew.NotAfter, dbo.NotAfter) + assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs) + assert.Equals(t, dbNew.Identifiers, dbo.Identifiers) + assert.Equals(t, dbNew.Error.Error(), o.Error.Error()) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateOrder(context.Background(), tc.o); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.o.ID, dbo.ID) + assert.Equals(t, tc.o.CertificateID, "certID") + assert.Equals(t, tc.o.Status, acme.StatusValid) + assert.Equals(t, tc.o.Error.Error(), acme.NewError(acme.ErrorMalformedType, "force").Error()) + } + } + }) + } +} + +func TestDB_CreateOrder(t *testing.T) { + now := clock.Now() + type test struct { + db nosql.DB + o *acme.Order + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/order-save-error": func(t *testing.T) test { + o := &acme.Order{ + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, string(bucket), string(orderTable)) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) + + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nil, false, errors.New("force") + }, + }, + o: o, + err: errors.New("error saving acme order: force"), + } + }, + "fail/orderIDsByOrderUpdate-error": func(t *testing.T) test { + o := &acme.Order{ + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) + assert.Equals(t, string(key), o.AccountID) + return nil, errors.New("force") + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, string(bucket), string(orderTable)) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) + + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nu, true, nil + }, + }, + o: o, + err: errors.New("error loading orderIDs for account accID: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idptr = &id + ) + + o := &acme.Order{ + AccountID: "accID", + ProvisionerID: "provID", + Status: acme.StatusValid, + ExpiresAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) + assert.Equals(t, string(key), o.AccountID) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idptr = string(key) + assert.Equals(t, string(bucket), string(orderTable)) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) + + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nu, true, nil + }, + }, + o: o, + _id: idptr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateOrder(context.Background(), tc.o); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.o.ID, *tc._id) + } + } + }) + } +} From 3612a0b990ca30ce29e1072a087d91a9f31f86c1 Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 23 Mar 2021 22:12:25 -0700 Subject: [PATCH 27/47] gethttp01 validate unit tests working --- acme/challenge_test.go | 2902 ++++++++++++++++++----------------- acme/db/nosql/authz.go | 42 +- acme/db/nosql/authz_test.go | 76 +- acme/db/nosql/nosql.go | 14 +- acme/db/nosql/nosql_test.go | 13 + acme/db/nosql/order.go | 38 +- acme/db/nosql/order_test.go | 484 +++++- 7 files changed, 2078 insertions(+), 1491 deletions(-) diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 11b30961..29bd5a71 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -1,373 +1,431 @@ package acme -/* -var testOps = ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: "zap.internal", - }, -} - -func newDNSCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newDNS01Challenge(mockdb, testOps) -} - -func newTLSALPNCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newTLSALPN01Challenge(mockdb, testOps) -} - -func newHTTPCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, testOps) -} - -func newHTTPChWithServer(host string) (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: host, - }, - }) -} +import ( + "bytes" + "context" + "crypto" + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "go.step.sm/crypto/jose" +) -func TestNewHTTP01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } +func TestKeyAuthorization(t *testing.T) { type test struct { - ops ChallengeOptions - db nosql.DB - err *Error + token string + jwk *jose.JSONWebKey + exp string + err error } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + tests := map[string]func(t *testing.T) test{ + "fail/jwk-thumbprint-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + return test{ + token: "1234", + jwk: jwk, + err: errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), + } }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, + "ok": func(t *testing.T) test { + token := "1234" + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + thumbprint, err := jwk.Thumbprint(crypto.SHA256) + assert.FatalError(t, err) + encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) + return test{ + token: token, + jwk: jwk, + exp: fmt.Sprintf("%s.%s", token, encPrint), + } }, } - for name, tc := range tests { + for name, run := range tests { t.Run(name, func(t *testing.T) { - ch, err := newHTTP01Challenge(tc.db, tc.ops) - if err != nil { + tc := run(t) + if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "http-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") + assert.Equals(t, tc.exp, ka) } } }) } } -func TestNewTLSALPN01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } +type errReader int + +func (errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("force") +} +func (errReader) Close() error { + return nil +} + +func TestHTTP01Validate(t *testing.T) { type test struct { - ops ChallengeOptions - db nosql.DB + vo *ValidateChallengeOptions + ch *Challenge + jwk *jose.JSONWebKey + db DB err *Error } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + tests := map[string]func(t *testing.T) test{ + "fail/http-get-error-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, }, - }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newTLSALPN01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "tls-alpn-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } + "ok/http-get-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", } - }) - } -} -func TestNewDNS01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "dns", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil + "fail/http-get->=400-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + }, nil + }, }, - }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newDNS01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "dns-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } + "ok/http-get->=400": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", } - }) - } -} - -func TestChallengeToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Validated = clock.Now() - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - tests := map[string]challenge{ - "dns": dnsCh, - "http": httpCh, - "tls-alpn": tlsALPNCh, - } - for name, ch := range tests { - t.Run(name, func(t *testing.T) { - ach, err := ch.toACME(ctx, nil, dir) - assert.FatalError(t, err) - assert.Equals(t, ach.Type, ch.getType()) - assert.Equals(t, ach.Status, ch.getStatus()) - assert.Equals(t, ach.Token, ch.getToken()) - assert.Equals(t, ach.URL, - fmt.Sprintf("%s/acme/%s/challenge/%s", - baseURL.String(), provName, ch.getID())) - assert.Equals(t, ach.ID, ch.getID()) - assert.Equals(t, ach.AuthzID, ch.getAuthzID()) + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + }, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/read-body": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } - if ach.Type == "http-01" { - v, err := time.Parse(time.RFC3339, ach.Validated) - assert.FatalError(t, err) - assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) - } else { - assert.Equals(t, ach.Validated, "") + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: errReader(0), + }, nil + }, + }, + err: NewErrorISE("error reading response body for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token), + } + }, + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", } - }) - } -} -func TestChallengeSave(t *testing.T) { - type test struct { - ch challenge - old challenge - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - httpCh, err := newHTTPCh() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + jwk.Key = "foo" return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), + }, nil }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + jwk: jwk, + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, - "fail/old-nil/swap-false": func(t *testing.T) test { - httpCh, err := newHTTPCh() + "ok/key-auth-mismatch": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), + }, nil + }, + }, + jwk: jwk, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), } }, - "ok/old-nil": func(t *testing.T) test { - httpCh, err := newHTTPCh() + "fail/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - b, err := json.Marshal(httpCh) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), + }, nil + }, + }, + jwk: jwk, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/old-not-nil": func(t *testing.T) test { - oldHTTPCh, err := newHTTPCh() + "fail/update-challenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - httpCh, err := newHTTPCh() + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + }, nil + }, + }, + jwk: jwk, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) - oldb, err := json.Marshal(oldHTTPCh) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating challenge: force"), + } + }, + "ok": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - b, err := json.Marshal(httpCh) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) return test{ - ch: httpCh, - old: oldHTTPCh, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + }, nil + }, + }, + jwk: jwk, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) + return nil }, }, } @@ -376,13 +434,18 @@ func TestChallengeSave(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := tc.ch.save(tc.db, tc.old); err != nil { + if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { assert.Nil(t, tc.err) @@ -391,451 +454,327 @@ func TestChallengeSave(t *testing.T) { } } -func TestChallengeClone(t *testing.T) { - ch, err := newHTTPCh() - assert.FatalError(t, err) - - clone := ch.clone() - - assert.Equals(t, clone.getID(), ch.getID()) - assert.Equals(t, clone.getAccountID(), ch.getAccountID()) - assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, clone.getStatus(), ch.getStatus()) - assert.Equals(t, clone.getToken(), ch.getToken()) - assert.Equals(t, clone.getCreated(), ch.getCreated()) - assert.Equals(t, clone.getValidated(), ch.getValidated()) - - clone.Status = StatusValid - - assert.NotEquals(t, clone.getStatus(), ch.getStatus()) -} - -func TestChallengeUnmarshal(t *testing.T) { +/* +func TestTLSALPN01Validate(t *testing.T) { type test struct { + srv *httptest.Server + vo validateOptions ch challenge - chb []byte + res challenge + jwk *jose.JSONWebKey + db nosql.DB err *Error } tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { + "ok/status-already-valid": func(t *testing.T) test { + ch, err := newTLSALPNCh() + assert.FatalError(t, err) + _ch, ok := ch.(*tlsALPN01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusValid + return test{ - chb: nil, - err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), + ch: ch, + res: ch, } }, - "fail/unexpected-type-http": func(t *testing.T) test { - httpCh, err := newHTTPCh() + "ok/status-already-invalid": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Type = "foo" - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _tlsALPNCh.baseChallenge.Type = "foo" - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) + _ch, ok := ch.(*tlsALPN01Challenge) assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - chb: b, - } - }, - "ok/http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - "ok/alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) + _ch.baseChallenge.Status = StatusInvalid + return test{ - ch: tlsALPNCh, - chb: b, + ch: ch, + res: ch, } }, - "ok/err": func(t *testing.T) test { - httpCh, err := newHTTPCh() + "ok/tls-dial-error": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() - b, err := json.Marshal(httpCh) + oldb, err := json.Marshal(ch) assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := unmarshalChallenge(tc.chb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } - } - }) - } -} -func TestGetChallenge(t *testing.T) { - type test struct { - id string - db nosql.DB - ch challenge - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - dnsCh, err := newDNSCh() + + expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: force", ch.getValue())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) assert.FatalError(t, err) + return test{ - ch: dnsCh, - id: dnsCh.getID(), + ch: ch, + vo: validateOptions{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return nil, errors.New("force") + }, + }, db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil }, }, - err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), + res: ch, } }, - "fail/db-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() + "ok/timeout": func(t *testing.T) test { + ch, err := newTLSALPNCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.getValue())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(nil) + // srv.Start() - do not start server to cause timeout + return test{ - ch: dnsCh, - id: dnsCh.getID(), + srv: srv, + ch: ch, + vo: validateOptions{ + tlsDial: tlsDial, + }, db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, string(newval), string(newb)) + return nil, true, nil }, }, - err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), + res: ch, } }, - "fail/unmarshal-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() + "ok/no-certificates": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) + oldb, err := json.Marshal(ch) assert.FatalError(t, err) + + expErr := RejectedIdentifierErr(errors.Errorf("tls-alpn-01 challenge for %v resulted in no certificates", ch.getValue())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + return test{ - ch: dnsCh, - id: dnsCh.getID(), + ch: ch, + vo: validateOptions{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Client(&noopConn{}, config), nil + }, + }, db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, string(newval), string(newb)) + return nil, true, nil }, }, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), + res: ch, } }, - "ok": func(t *testing.T) test { - dnsCh, err := newDNSCh() + "ok/no-names": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ - ch: dnsCh, - id: dnsCh.getID(), + srv: srv, + ch: ch, + vo: validateOptions{ + tlsDial: tlsDial, + }, + jwk: jwk, db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, string(newval), string(newb)) + return nil, true, nil }, }, + res: ch, } }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := getChallenge(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } - } - }) - } -} + "ok/too-many-names": func(t *testing.T) test { + ch, err := newTLSALPNCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) -func TestKeyAuthorization(t *testing.T) { - type test struct { - token string - jwk *jose.JSONWebKey - exp string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/jwk-thumbprint-error": func(t *testing.T) test { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) assert.FatalError(t, err) - jwk.Key = "foo" - return test{ - token: "1234", - jwk: jwk, - err: ServerInternalErr(errors.Errorf("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "ok": func(t *testing.T) test { - token := "1234" + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - thumbprint, err := jwk.Thumbprint(crypto.SHA256) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) assert.FatalError(t, err) - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - return test{ - token: token, - jwk: jwk, - exp: fmt.Sprintf("%s.%s", token, encPrint), - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.exp, ka) - } - } - }) - } -} + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) -type errReader int + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue(), "other.internal") + assert.FatalError(t, err) -func (errReader) Read(p []byte) (n int, err error) { - return 0, errors.New("force") -} -func (errReader) Close() error { - return nil -} + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() -func TestHTTP01Validate(t *testing.T) { - type test struct { - vo validateOptions - ch challenge - res challenge - jwk *jose.JSONWebKey - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - return test{ - ch: ch, - res: ch, - } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid return test{ + srv: srv, ch: ch, + vo: validateOptions{ + tlsDial: tlsDial, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, string(newval), string(newb)) + return nil, true, nil + }, + }, res: ch, } }, - "ok/http-get-error": func(t *testing.T) test { - ch, err := newHTTPCh() + "ok/wrong-name": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ - "http://zap.internal/.well-known/acme-challenge/%s: force", ch.getToken())) + expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) baseClone := ch.clone() baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} + newCh := &tlsALPN01Challenge{baseClone} newb, err := json.Marshal(newCh) assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ - ch: ch, + srv: srv, + ch: ch, vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return nil, errors.New("force") - }, + tlsDial: tlsDial, }, + jwk: jwk, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, key, []byte(ch.getID())) assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) + assert.Equals(t, string(newval), string(newb)) return nil, true, nil }, }, res: ch, } }, - "ok/http-get->=400": func(t *testing.T) test { - ch, err := newHTTPCh() + "ok/no-extension": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ - "http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.getToken())) + expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) baseClone := ch.clone() baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} + newCh := &tlsALPN01Challenge{baseClone} newb, err := json.Marshal(newCh) assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + cert, err := newTLSALPNValidationCert(nil, false, true, ch.getValue()) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ - ch: ch, + srv: srv, + ch: ch, vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusBadRequest, - }, nil - }, + tlsDial: tlsDial, }, + jwk: jwk, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, key, []byte(ch.getID())) assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) + assert.Equals(t, string(newval), string(newb)) return nil, true, nil }, }, res: ch, } }, - "fail/read-body": func(t *testing.T) test { - ch, err := newHTTPCh() + "ok/extension-not-critical": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + oldb, err := json.Marshal(ch) assert.FatalError(t, err) - jwk.Key = "foo" - return test{ - ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return &http.Response{ - Body: errReader(0), - }, nil - }, - }, - jwk: jwk, - err: ServerInternalErr(errors.Errorf("error reading response "+ - "body for url http://zap.internal/.well-known/acme-challenge/%s: force", - ch.getToken())), - } - }, - "fail/key-authorization-gen-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - jwk.Key = "foo" - return test{ - ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return &http.Response{ - Body: ioutil.NopCloser(bytes.NewBufferString("foo")), - }, nil - }, - }, - jwk: jwk, - err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "ok/key-auth-mismatch": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -843,23 +782,19 @@ func TestHTTP01Validate(t *testing.T) { expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got foo", expKeyAuth)) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.getValue()) assert.FatalError(t, err) + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ - ch: ch, + srv: srv, + ch: ch, vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return &http.Response{ - Body: ioutil.NopCloser(bytes.NewBufferString("foo")), - }, nil - }, + tlsDial: tlsDial, }, jwk: jwk, db: &db.MockNoSQLDB{ @@ -867,69 +802,40 @@ func TestHTTP01Validate(t *testing.T) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, key, []byte(ch.getID())) assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) + assert.Equals(t, string(newval), string(newb)) return nil, true, nil }, }, res: ch, } }, - "fail/save-error": func(t *testing.T) test { - ch, err := newHTTPCh() + "ok/extension-malformed": func(t *testing.T) test { + ch, err := newTLSALPNCh() assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - return test{ - ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return &http.Response{ - Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), - }, nil - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - } - }, - "ok": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) + expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &tlsALPN01Challenge{baseClone} + newb, err := json.Marshal(newCh) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.getValue()) assert.FatalError(t, err) - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &http01Challenge{baseClone} + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() return test{ + srv: srv, ch: ch, - res: newCh, vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { - return &http.Response{ - Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), - }, nil - }, + tlsDial: tlsDial, }, jwk: jwk, db: &db.MockNoSQLDB{ @@ -937,137 +843,40 @@ func TestHTTP01Validate(t *testing.T) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, key, []byte(ch.getID())) assert.Equals(t, old, oldb) - - httpCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, httpCh.getStatus(), StatusValid) - assert.True(t, httpCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, httpCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) - - baseClone.Validated = httpCh.getValidated() - + assert.Equals(t, string(newval), string(newb)) return nil, true, nil }, }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } - } - }) - } -} - -func TestTLSALPN01Validate(t *testing.T) { - type test struct { - srv *httptest.Server - vo validateOptions - ch challenge - res challenge - jwk *jose.JSONWebKey - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - - return test{ - ch: ch, - res: ch, - } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid - - return test{ - ch: ch, res: ch, } }, - "ok/tls-dial-error": func(t *testing.T) test { + "ok/no-protocol": func(t *testing.T) test { ch, err := newTLSALPNCh() assert.FatalError(t, err) oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: force", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - return test{ - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return nil, errors.New("force") - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/timeout": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.getValue())) + expErr := RejectedIdentifierErr(errors.New("cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) baseClone := ch.clone() baseClone.Error = expErr.ToACME() newCh := &tlsALPN01Challenge{baseClone} newb, err := json.Marshal(newCh) assert.FatalError(t, err) - srv, tlsDial := newTestTLSALPNServer(nil) - // srv.Start() - do not start server to cause timeout + srv := httptest.NewTLSServer(nil) return test{ srv: srv, ch: ch, vo: validateOptions{ - tlsDial: tlsDial, + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + }, }, + jwk: jwk, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) @@ -1080,59 +889,30 @@ func TestTLSALPN01Validate(t *testing.T) { res: ch, } }, - "ok/no-certificates": func(t *testing.T) test { + "ok/mismatched-token": func(t *testing.T) test { ch, err := newTLSALPNCh() assert.FatalError(t, err) oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.Errorf("tls-alpn-01 challenge for %v resulted in no certificates", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - return test{ - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.Client(&noopConn{}, config), nil - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/no-names": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + incorrectTokenHash := sha256.Sum256([]byte("mismatched")) - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) + expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:]))) baseClone := ch.clone() baseClone.Error = expErr.ToACME() newCh := &tlsALPN01Challenge{baseClone} newb, err := json.Marshal(newCh) assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) + cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.getValue()) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) @@ -1157,27 +937,28 @@ func TestTLSALPN01Validate(t *testing.T) { res: ch, } }, - "ok/too-many-names": func(t *testing.T) test { + "ok/obsolete-oid": func(t *testing.T) test { ch, err := newTLSALPNCh() assert.FatalError(t, err) oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: " + + "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) baseClone := ch.clone() baseClone.Error = expErr.ToACME() newCh := &tlsALPN01Challenge{baseClone} newb, err := json.Marshal(newCh) assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) assert.FatalError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue(), "other.internal") + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.getValue()) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) @@ -1202,18 +983,19 @@ func TestTLSALPN01Validate(t *testing.T) { res: ch, } }, - "ok/wrong-name": func(t *testing.T) test { + "ok": func(t *testing.T) test { ch, err := newTLSALPNCh() assert.FatalError(t, err) + _ch, ok := ch.(*tlsALPN01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Error = MalformedErr(nil).ToACME() oldb, err := json.Marshal(ch) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) baseClone := ch.clone() - baseClone.Error = expErr.ToACME() + baseClone.Status = StatusValid + baseClone.Error = nil newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) @@ -1222,7 +1004,7 @@ func TestTLSALPN01Validate(t *testing.T) { assert.FatalError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue()) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) @@ -1232,7 +1014,15 @@ func TestTLSALPN01Validate(t *testing.T) { srv: srv, ch: ch, vo: validateOptions{ - tlsDial: tlsDial, + tlsDial: func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { + assert.Equals(t, network, "tcp") + assert.Equals(t, addr, net.JoinHostPort(newCh.getValue(), "443")) + assert.Equals(t, config.NextProtos, []string{"acme-tls/1"}) + assert.Equals(t, config.ServerName, newCh.getValue()) + assert.True(t, config.InsecureSkipVerify) + + return tlsDial(network, addr, config) + }, }, jwk: jwk, db: &db.MockNoSQLDB{ @@ -1240,65 +1030,272 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, key, []byte(ch.getID())) assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) + + alpnCh, err := unmarshalChallenge(newval) + assert.FatalError(t, err) + assert.Equals(t, alpnCh.getStatus(), StatusValid) + assert.True(t, alpnCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, alpnCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + + baseClone.Validated = alpnCh.getValidated() + return nil, true, nil }, }, - res: ch, + res: newCh, } }, - "ok/no-extension": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) - baseClone := ch.clone() + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + if tc.srv != nil { + defer tc.srv.Close() + } + + if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res.getID(), ch.getID()) + assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.res.getStatus(), ch.getStatus()) + assert.Equals(t, tc.res.getToken(), ch.getToken()) + assert.Equals(t, tc.res.getCreated(), ch.getCreated()) + assert.Equals(t, tc.res.getValidated(), ch.getValidated()) + assert.Equals(t, tc.res.getError(), ch.getError()) + } + } + }) + } +} + +func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { + srv := httptest.NewUnstartedServer(http.NewServeMux()) + + srv.Config.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){ + "acme-tls/1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + // no-op + }, + "http/1.1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + panic("unexpected http/1.1 next proto") + }, + } + + srv.TLS = &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == "acme-tls/1" { + return validationCert, nil + } + return nil, nil + }, + NextProtos: []string{ + "acme-tls/1", + "http/1.1", + }, + } + + srv.Listener = tls.NewListener(srv.Listener, srv.TLS) + //srv.Config.ErrorLog = log.New(ioutil.Discard, "", 0) // hush + + return srv, func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + } +} + +// noopConn is a mock net.Conn that does nothing. +type noopConn struct{} + +func (c *noopConn) Read(_ []byte) (n int, err error) { return 0, io.EOF } +func (c *noopConn) Write(_ []byte) (n int, err error) { return 0, io.EOF } +func (c *noopConn) Close() error { return nil } +func (c *noopConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } +func (c *noopConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } +func (c *noopConn) SetDeadline(t time.Time) error { return nil } +func (c *noopConn) SetReadDeadline(t time.Time) error { return nil } +func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil } + +func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, names ...string) (*tls.Certificate, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + certTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1337), + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, 1), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: names, + } + + if keyAuthHash != nil { + oid := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} + if obsoleteOID { + oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} + } + + keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) + + certTemplate.ExtraExtensions = []pkix.Extension{ + { + Id: oid, + Critical: critical, + Value: keyAuthHashEnc, + }, + } + } + + cert, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, privateKey.Public(), privateKey) + if err != nil { + return nil, err + } + + return &tls.Certificate{ + PrivateKey: privateKey, + Certificate: [][]byte{cert}, + }, nil +} + +func TestDNS01Validate(t *testing.T) { + type test struct { + vo validateOptions + ch challenge + res challenge + jwk *jose.JSONWebKey + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/status-already-valid": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusValid + return test{ + ch: ch, + res: ch, + } + }, + "ok/status-already-invalid": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusInvalid + return test{ + ch: ch, + res: ch, + } + }, + "ok/lookup-txt-error": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + expErr := DNSErr(errors.Errorf("error looking up TXT records for "+ + "domain %s: force", ch.getValue())) + baseClone := ch.clone() baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} + newCh := &dns01Challenge{baseClone} newb, err := json.Marshal(newCh) assert.FatalError(t, err) + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "ok/lookup-txt-wildcard": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Value = "*.zap.internal" jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - cert, err := newTLSALPNValidationCert(nil, false, true, ch.getValue()) + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + baseClone := ch.clone() + baseClone.Status = StatusValid + baseClone.Error = nil + newCh := &dns01Challenge{baseClone} return test{ - srv: srv, ch: ch, + res: newCh, vo: validateOptions{ - tlsDial: tlsDial, + lookupTxt: func(url string) ([]string, error) { + assert.Equals(t, url, "_acme-challenge.zap.internal") + return []string{"foo", expected}, nil + }, }, jwk: jwk, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) + dnsCh, err := unmarshalChallenge(newval) + assert.FatalError(t, err) + assert.Equals(t, dnsCh.getStatus(), StatusValid) + baseClone.Validated = dnsCh.getValidated() return nil, true, nil }, }, - res: ch, } }, - "ok/extension-not-critical": func(t *testing.T) test { - ch, err := newTLSALPNCh() + "fail/key-authorization-gen-error": func(t *testing.T) test { + ch, err := newDNSCh() assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + jwk.Key = "foo" + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil + }, + }, + jwk: jwk, + err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + } + }, + "ok/key-auth-mismatch": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1306,279 +1303,517 @@ func TestTLSALPN01Validate(t *testing.T) { expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.getValue()) + expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ + "expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &http01Challenge{baseClone} + newb, err := json.Marshal(newCh) assert.FatalError(t, err) - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "fail/save-error": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + } + }, + "ok": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Error = MalformedErr(nil).ToACME() + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + + baseClone := ch.clone() + baseClone.Status = StatusValid + baseClone.Error = nil + newCh := &dns01Challenge{baseClone} + + return test{ + ch: ch, + res: newCh, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + + dnsCh, err := unmarshalChallenge(newval) + assert.FatalError(t, err) + assert.Equals(t, dnsCh.getStatus(), StatusValid) + assert.True(t, dnsCh.getValidated().Before(time.Now().UTC())) + assert.True(t, dnsCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + + baseClone.Validated = dnsCh.getValidated() + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res.getID(), ch.getID()) + assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.res.getStatus(), ch.getStatus()) + assert.Equals(t, tc.res.getToken(), ch.getToken()) + assert.Equals(t, tc.res.getCreated(), ch.getCreated()) + assert.Equals(t, tc.res.getValidated(), ch.getValidated()) + assert.Equals(t, tc.res.getError(), ch.getError()) + } + } + }) + } +} + +/* +var testOps = ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "", // will get set correctly depending on the "new.." method. + Value: "zap.internal", + }, +} + +func newDNSCh() (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newDNS01Challenge(mockdb, testOps) +} + +func newTLSALPNCh() (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newTLSALPN01Challenge(mockdb, testOps) +} + +func newHTTPCh() (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newHTTP01Challenge(mockdb, testOps) +} + +func newHTTPChWithServer(host string) (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newHTTP01Challenge(mockdb, ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "", // will get set correctly depending on the "new.." method. + Value: host, + }, + }) +} + +func TestNewHTTP01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "http", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newHTTP01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "http-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestNewTLSALPN01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "http", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newTLSALPN01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "tls-alpn-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestNewDNS01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "dns", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newDNS01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "dns-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } } - }, - "ok/extension-malformed": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + }) + } +} - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) +func TestChallengeToACME(t *testing.T) { + dir := newDirectory("ca.smallstep.com", "acme") - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Validated = clock.Now() + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + tlsALPNCh, err := newTLSALPNCh() + assert.FatalError(t, err) - cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.getValue()) + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) + tests := map[string]challenge{ + "dns": dnsCh, + "http": httpCh, + "tls-alpn": tlsALPNCh, + } + for name, ch := range tests { + t.Run(name, func(t *testing.T) { + ach, err := ch.toACME(ctx, nil, dir) assert.FatalError(t, err) - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + assert.Equals(t, ach.Type, ch.getType()) + assert.Equals(t, ach.Status, ch.getStatus()) + assert.Equals(t, ach.Token, ch.getToken()) + assert.Equals(t, ach.URL, + fmt.Sprintf("%s/acme/%s/challenge/%s", + baseURL.String(), provName, ch.getID())) + assert.Equals(t, ach.ID, ch.getID()) + assert.Equals(t, ach.AuthzID, ch.getAuthzID()) - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, + if ach.Type == "http-01" { + v, err := time.Parse(time.RFC3339, ach.Validated) + assert.FatalError(t, err) + assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) + } else { + assert.Equals(t, ach.Validated, "") } - }, - "ok/no-protocol": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + }) + } +} - expErr := RejectedIdentifierErr(errors.New("cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) +func TestChallengeSave(t *testing.T) { + type test struct { + ch challenge + old challenge + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + httpCh, err := newHTTPCh() assert.FatalError(t, err) - - srv := httptest.NewTLSServer(nil) - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) - }, - }, - jwk: jwk, + ch: httpCh, + old: nil, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + return nil, false, errors.New("force") }, }, - res: ch, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), } }, - "ok/mismatched-token": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - incorrectTokenHash := sha256.Sum256([]byte("mismatched")) - - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "expected acmeValidationV1 extension value %s for this challenge but got %s", - hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:]))) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.getValue()) + "fail/old-nil/swap-false": func(t *testing.T) test { + httpCh, err := newHTTPCh() assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, + ch: httpCh, + old: nil, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + return []byte("foo"), false, nil }, }, - res: ch, + err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), } }, - "ok/obsolete-oid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: " + - "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + "ok/old-nil": func(t *testing.T) test { + httpCh, err := newHTTPCh() assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.getValue()) + b, err := json.Marshal(httpCh) assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, + ch: httpCh, + old: nil, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, []byte(httpCh.getID()), key) + return []byte("foo"), true, nil }, }, - res: ch, } }, - "ok": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) + "ok/old-not-nil": func(t *testing.T) test { + oldHTTPCh, err := newHTTPCh() assert.FatalError(t, err) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &tlsALPN01Challenge{baseClone} - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + httpCh, err := newHTTPCh() assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + oldb, err := json.Marshal(oldHTTPCh) assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue()) + b, err := json.Marshal(httpCh) assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { - assert.Equals(t, network, "tcp") - assert.Equals(t, addr, net.JoinHostPort(newCh.getValue(), "443")) - assert.Equals(t, config.NextProtos, []string{"acme-tls/1"}) - assert.Equals(t, config.ServerName, newCh.getValue()) - assert.True(t, config.InsecureSkipVerify) - - return tlsDial(network, addr, config) - }, - }, - jwk: jwk, + ch: httpCh, + old: oldHTTPCh, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) assert.Equals(t, old, oldb) - - alpnCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, alpnCh.getStatus(), StatusValid) - assert.True(t, alpnCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, alpnCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) - - baseClone.Validated = alpnCh.getValidated() - - return nil, true, nil + assert.Equals(t, b, newval) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, []byte(httpCh.getID()), key) + return []byte("foo"), true, nil }, }, - res: newCh, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - - if tc.srv != nil { - defer tc.srv.Close() - } - - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if err := tc.ch.save(tc.db, tc.old); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1587,348 +1822,222 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, ae.Type, tc.err.Type) } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } + assert.Nil(t, tc.err) } }) } } -func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { - srv := httptest.NewUnstartedServer(http.NewServeMux()) - - srv.Config.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){ - "acme-tls/1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { - // no-op - }, - "http/1.1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { - panic("unexpected http/1.1 next proto") - }, - } - - srv.TLS = &tls.Config{ - GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == "acme-tls/1" { - return validationCert, nil - } - return nil, nil - }, - NextProtos: []string{ - "acme-tls/1", - "http/1.1", - }, - } - - srv.Listener = tls.NewListener(srv.Listener, srv.TLS) - //srv.Config.ErrorLog = log.New(ioutil.Discard, "", 0) // hush - - return srv, func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { - return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) - } -} - -// noopConn is a mock net.Conn that does nothing. -type noopConn struct{} - -func (c *noopConn) Read(_ []byte) (n int, err error) { return 0, io.EOF } -func (c *noopConn) Write(_ []byte) (n int, err error) { return 0, io.EOF } -func (c *noopConn) Close() error { return nil } -func (c *noopConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } -func (c *noopConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } -func (c *noopConn) SetDeadline(t time.Time) error { return nil } -func (c *noopConn) SetReadDeadline(t time.Time) error { return nil } -func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil } - -func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, names ...string) (*tls.Certificate, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - - certTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1337), - Subject: pkix.Name{ - Organization: []string{"Test"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - DNSNames: names, - } - - if keyAuthHash != nil { - oid := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} - if obsoleteOID { - oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} - } +func TestChallengeClone(t *testing.T) { + ch, err := newHTTPCh() + assert.FatalError(t, err) - keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) + clone := ch.clone() - certTemplate.ExtraExtensions = []pkix.Extension{ - { - Id: oid, - Critical: critical, - Value: keyAuthHashEnc, - }, - } - } + assert.Equals(t, clone.getID(), ch.getID()) + assert.Equals(t, clone.getAccountID(), ch.getAccountID()) + assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, clone.getStatus(), ch.getStatus()) + assert.Equals(t, clone.getToken(), ch.getToken()) + assert.Equals(t, clone.getCreated(), ch.getCreated()) + assert.Equals(t, clone.getValidated(), ch.getValidated()) - cert, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, privateKey.Public(), privateKey) - if err != nil { - return nil, err - } + clone.Status = StatusValid - return &tls.Certificate{ - PrivateKey: privateKey, - Certificate: [][]byte{cert}, - }, nil + assert.NotEquals(t, clone.getStatus(), ch.getStatus()) } -func TestDNS01Validate(t *testing.T) { +func TestChallengeUnmarshal(t *testing.T) { type test struct { - vo validateOptions ch challenge - res challenge - jwk *jose.JSONWebKey - db nosql.DB + chb []byte err *Error } tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newDNSCh() + "fail/nil": func(t *testing.T) test { + return test{ + chb: nil, + err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), + } + }, + "fail/unexpected-type-http": func(t *testing.T) test { + httpCh, err := newHTTPCh() assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) + _httpCh, ok := httpCh.(*http01Challenge) assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid + _httpCh.baseChallenge.Type = "foo" + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) return test{ - ch: ch, - res: ch, + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), } }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newDNSCh() + "fail/unexpected-type-alpn": func(t *testing.T) test { + tlsALPNCh, err := newTLSALPNCh() assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) + _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge) assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid + _tlsALPNCh.baseChallenge.Type = "foo" + b, err := json.Marshal(tlsALPNCh) + assert.FatalError(t, err) return test{ - ch: ch, - res: ch, + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), } }, - "ok/lookup-txt-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + "fail/unexpected-type-dns": func(t *testing.T) test { + dnsCh, err := newDNSCh() assert.FatalError(t, err) - - expErr := DNSErr(errors.Errorf("error looking up TXT records for "+ - "domain %s: force", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &dns01Challenge{baseClone} - newb, err := json.Marshal(newCh) + _dnsCh, ok := dnsCh.(*dns01Challenge) + assert.Fatal(t, ok) + _dnsCh.baseChallenge.Type = "foo" + b, err := json.Marshal(dnsCh) assert.FatalError(t, err) - return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return nil, errors.New("force") - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil - }, - }, - res: ch, + return test{ + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), } }, - "ok/lookup-txt-wildcard": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Value = "*.zap.internal" - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + "ok/dns": func(t *testing.T) test { + dnsCh, err := newDNSCh() assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + b, err := json.Marshal(dnsCh) assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &dns01Challenge{baseClone} - return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - assert.Equals(t, url, "_acme-challenge.zap.internal") - return []string{"foo", expected}, nil - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - dnsCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, dnsCh.getStatus(), StatusValid) - baseClone.Validated = dnsCh.getValidated() - return nil, true, nil - }, - }, + ch: dnsCh, + chb: b, } }, - "fail/key-authorization-gen-error": func(t *testing.T) test { - ch, err := newDNSCh() + "ok/http": func(t *testing.T) test { + httpCh, err := newHTTPCh() assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + b, err := json.Marshal(httpCh) assert.FatalError(t, err) - jwk.Key = "foo" return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", "bar"}, nil - }, - }, - jwk: jwk, - err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + ch: httpCh, + chb: b, } }, - "ok/key-auth-mismatch": func(t *testing.T) test { - ch, err := newDNSCh() + "ok/alpn": func(t *testing.T) test { + tlsALPNCh, err := newTLSALPNCh() assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + b, err := json.Marshal(tlsALPNCh) assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + return test{ + ch: tlsALPNCh, + chb: b, + } + }, + "ok/err": func(t *testing.T) test { + httpCh, err := newHTTPCh() assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() + b, err := json.Marshal(httpCh) assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) + return test{ + ch: httpCh, + chb: b, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := unmarshalChallenge(tc.chb); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.getID(), ch.getID()) + assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) + assert.Equals(t, tc.ch.getToken(), ch.getToken()) + assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) + assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) + } + } + }) + } +} +func TestGetChallenge(t *testing.T) { + type test struct { + id string + db nosql.DB + ch challenge + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + dnsCh, err := newDNSCh() assert.FatalError(t, err) - return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", "bar"}, nil + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound }, }, - jwk: jwk, + err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), + } + }, + "fail/db-error": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") }, }, - res: ch, + err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), } }, - "fail/save-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + "fail/unmarshal-error": func(t *testing.T) test { + dnsCh, err := newDNSCh() assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + _dnsCh, ok := dnsCh.(*dns01Challenge) + assert.Fatal(t, ok) + _dnsCh.baseChallenge.Type = "foo" + b, err := json.Marshal(dnsCh) assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", expected}, nil - }, - }, - jwk: jwk, + ch: dnsCh, + id: dnsCh.getID(), db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(dnsCh.getID())) + return b, nil }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + err: ServerInternalErr(errors.New("unexpected challenge type foo")), } }, "ok": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + dnsCh, err := newDNSCh() assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + b, err := json.Marshal(dnsCh) assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &dns01Challenge{baseClone} - return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", expected}, nil - }, - }, - jwk: jwk, + ch: dnsCh, + id: dnsCh.getID(), db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - - dnsCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, dnsCh.getStatus(), StatusValid) - assert.True(t, dnsCh.getValidated().Before(time.Now().UTC())) - assert.True(t, dnsCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) - - baseClone.Validated = dnsCh.getValidated() - - return nil, true, nil + assert.Equals(t, key, []byte(dnsCh.getID())) + return b, nil }, }, } @@ -1937,7 +2046,7 @@ func TestDNS01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if ch, err := getChallenge(tc.db, tc.id); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1947,14 +2056,13 @@ func TestDNS01Validate(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) + assert.Equals(t, tc.ch.getID(), ch.getID()) + assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) + assert.Equals(t, tc.ch.getToken(), ch.getToken()) + assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) + assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) } } }) diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 2ea1bb69..449a9276 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -14,16 +14,16 @@ var defaultExpiryDuration = time.Hour * 24 // dbAuthz is the base authz type that others build from. type dbAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier acme.Identifier `json:"identifier"` - Status acme.Status `json:"status"` - ExpiresAt time.Time `json:"expiresAt"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - CreatedAt time.Time `json:"createdAt"` - Error *acme.Error `json:"error"` - Token string `json:"token"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier acme.Identifier `json:"identifier"` + Status acme.Status `json:"status"` + ExpiresAt time.Time `json:"expiresAt"` + ChallengeIDs []string `json:"challengeIDs"` + Wildcard bool `json:"wildcard"` + CreatedAt time.Time `json:"createdAt"` + Error *acme.Error `json:"error"` + Token string `json:"token"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -55,8 +55,8 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat if err != nil { return nil, err } - var chs = make([]*acme.Challenge, len(dbaz.Challenges)) - for i, chID := range dbaz.Challenges { + var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs)) + for i, chID := range dbaz.ChallengeIDs { chs[i], err = db.GetChallenge(ctx, chID, id) if err != nil { return nil, err @@ -91,15 +91,15 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e now := clock.Now() dbaz := &dbAuthz{ - ID: az.ID, - AccountID: az.AccountID, - Status: az.Status, - CreatedAt: now, - ExpiresAt: az.ExpiresAt, - Identifier: az.Identifier, - Challenges: chIDs, - Token: az.Token, - Wildcard: az.Wildcard, + ID: az.ID, + AccountID: az.AccountID, + Status: az.Status, + CreatedAt: now, + ExpiresAt: az.ExpiresAt, + Identifier: az.Identifier, + ChallengeIDs: chIDs, + Token: az.Token, + Wildcard: az.Wildcard, } return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable) diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go index 825c4648..0c2cec50 100644 --- a/acme/db/nosql/authz_test.go +++ b/acme/db/nosql/authz_test.go @@ -71,13 +71,13 @@ func TestDB_getDBAuthz(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -174,13 +174,13 @@ func TestDB_GetAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -212,13 +212,13 @@ func TestDB_GetAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -250,13 +250,13 @@ func TestDB_GetAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -374,7 +374,7 @@ func TestDB_CreateAuthorization(t *testing.T) { }) assert.Equals(t, dbaz.Status, az.Status) assert.Equals(t, dbaz.Token, az.Token) - assert.Equals(t, dbaz.Challenges, []string{"foo", "bar"}) + assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) assert.Equals(t, dbaz.Wildcard, az.Wildcard) assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) assert.Nil(t, dbaz.Error) @@ -428,7 +428,7 @@ func TestDB_CreateAuthorization(t *testing.T) { }) assert.Equals(t, dbaz.Status, az.Status) assert.Equals(t, dbaz.Token, az.Token) - assert.Equals(t, dbaz.Challenges, []string{"foo", "bar"}) + assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) assert.Equals(t, dbaz.Wildcard, az.Wildcard) assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) assert.Nil(t, dbaz.Error) @@ -469,12 +469,12 @@ func TestDB_UpdateAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -530,7 +530,7 @@ func TestDB_UpdateAuthorization(t *testing.T) { assert.Equals(t, dbNew.Identifier, dbaz.Identifier) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.Token, dbaz.Token) - assert.Equals(t, dbNew.Challenges, dbaz.Challenges) + assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) @@ -580,7 +580,7 @@ func TestDB_UpdateAuthorization(t *testing.T) { assert.Equals(t, dbNew.Identifier, dbaz.Identifier) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.Token, dbaz.Token) - assert.Equals(t, dbNew.Challenges, dbaz.Challenges) + assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index bcb118d8..b8f79edc 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -42,9 +42,17 @@ func New(db nosqlDB.DB) (*DB, error) { // save writes the new data to the database, overwriting the old data if it // existed. func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { - newB, err := json.Marshal(nu) - if err != nil { - return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) + var ( + err error + newB []byte + ) + if nu == nil { + newB = nil + } else { + newB, err = json.Marshal(nu) + if err != nil { + return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) + } } var oldB []byte if old == nil { diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go index b7a91a2f..7fd21c50 100644 --- a/acme/db/nosql/nosql_test.go +++ b/acme/db/nosql/nosql_test.go @@ -110,6 +110,19 @@ func TestDB_save(t *testing.T) { }, }, }, + "ok/nils": test{ + nu: nil, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, old, nil) + assert.Equals(t, nu, nil) + return nu, true, nil + }, + }, + }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index a64316a6..c8fe53e1 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -127,15 +127,17 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st defer ordersByAccountMux.Unlock() b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) + var ( + oldOids []string + ) if err != nil { - if nosql.IsErrNotFound(err) { - return []string{}, nil + if !nosql.IsErrNotFound(err) { + return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) + } + } else { + if err := json.Unmarshal(b, &oldOids); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } - return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) - } - var oids []string - if err := json.Unmarshal(b, &oids); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } // Remove any order that is not in PENDING state and update the stored list @@ -145,7 +147,7 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st // The server SHOULD include pending orders and SHOULD NOT include orders // that are invalid in the array of URLs. pendOids := []string{} - for _, oid := range oids { + for _, oid := range oldOids { o, err := db.GetOrder(ctx, oid) if err != nil { return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID) @@ -158,15 +160,27 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st } } pendOids = append(pendOids, addOids...) - if len(oids) == 0 { - oids = nil + var ( + _old interface{} = oldOids + _new interface{} = pendOids + ) + switch { + case len(oldOids) == 0 && len(pendOids) == 0: + // If list has not changed from empty, then no need to write the DB. + return []string{}, nil + case len(oldOids) == 0: + _old = nil + case len(pendOids) == 0: + _new = nil } - if err = db.save(ctx, accID, pendOids, oids, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { + if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { // Delete all orders that may have been previously stored if orderIDsByAccountID update fails. for _, oid := range addOids { + // Ignore error from delete -- we tried our best. + // TODO when we have logging w/ request ID tracking, logging this error. db.db.Del(orderTable, []byte(oid)) } - return nil, errors.Wrap(err, "error saving OrderIDsByAccountID index") + return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID) } return pendOids, nil } diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go index 8ce7ac79..3636837c 100644 --- a/acme/db/nosql/order_test.go +++ b/acme/db/nosql/order_test.go @@ -3,6 +3,7 @@ package nosql import ( "context" "encoding/json" + "reflect" "testing" "time" @@ -511,27 +512,39 @@ func TestDB_CreateOrder(t *testing.T) { return nil, nosqldb.ErrNotFound }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { - *idptr = string(key) - assert.Equals(t, string(bucket), string(orderTable)) - assert.Equals(t, string(key), o.ID) - assert.Equals(t, old, nil) + switch string(bucket) { + case string(ordersByAccountIDTable): + b, err := json.Marshal([]string{o.ID}) + assert.FatalError(t, err) + assert.Equals(t, string(key), "accID") + assert.Equals(t, old, nil) + assert.Equals(t, nu, b) + return nu, true, nil + case string(orderTable): + *idptr = string(key) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) - dbo := new(dbOrder) - assert.FatalError(t, json.Unmarshal(nu, dbo)) - assert.Equals(t, dbo.ID, o.ID) - assert.Equals(t, dbo.AccountID, o.AccountID) - assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) - assert.Equals(t, dbo.CertificateID, "") - assert.Equals(t, dbo.Status, o.Status) - assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) - assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) - assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) - assert.Equals(t, dbo.NotBefore, o.NotBefore) - assert.Equals(t, dbo.NotAfter, o.NotAfter) - assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) - assert.Equals(t, dbo.Identifiers, o.Identifiers) - assert.Equals(t, dbo.Error, nil) - return nu, true, nil + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } }, }, o: o, @@ -555,3 +568,434 @@ func TestDB_CreateOrder(t *testing.T) { }) } } + +func TestDB_updateAddOrderIDs(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + addOids []string + res []string + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, errors.New("force") + }, + }, + err: errors.Errorf("error loading orderIDs for account %s", accID), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return []byte("foo"), nil + }, + }, + err: errors.Errorf("error unmarshaling orderIDs for account %s", accID), + } + }, + "fail/db.Get-order-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + b, err := json.Marshal([]string{"foo", "bar"}) + assert.FatalError(t, err) + return b, nil + case string(orderTable): + assert.Equals(t, key, []byte("foo")) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewErrorISE("error loading order foo for account accID: error loading order foo: force"), + } + }, + "fail/update-order-status-error": func(t *testing.T) test { + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + b, err := json.Marshal([]string{"foo", "bar"}) + assert.FatalError(t, err) + return b, nil + case string(orderTable): + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("foo")) + assert.Equals(t, old, bfoo) + + newdbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, newdbo)) + assert.Equals(t, newdbo.ID, "foo") + assert.Equals(t, newdbo.Status, acme.StatusInvalid) + assert.Equals(t, newdbo.ExpiresAt, expiry) + assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "order has expired").Error()) + return nil, false, errors.New("force") + }, + }, + acmeErr: acme.NewErrorISE("error updating order foo for account accID: error saving acme order: force"), + } + }, + "fail/db.save-order-error": func(t *testing.T) test { + addOids := []string{"foo", "bar"} + b, err := json.Marshal(addOids) + assert.FatalError(t, err) + delCount := 0 + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, nil) + assert.Equals(t, nu, b) + return nil, false, errors.New("force") + }, + MDel: func(bucket, key []byte) error { + delCount++ + switch delCount { + case 1: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("foo")) + return nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("bar")) + return nil + default: + assert.FatalError(t, errors.New("delete should only be called twice")) + return errors.New("force") + } + }, + }, + addOids: addOids, + err: errors.Errorf("error saving orderIDs index for account %s", accID), + } + }, + "ok/all-old-not-pending": func(t *testing.T) test { + oldOids := []string{"foo", "bar"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + obar := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bbar, err := json.Marshal(obar) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + case "bar": + assert.Equals(t, key, []byte("bar")) + return bbar, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, nil) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + res: []string{}, + } + }, + "ok/old-and-new": func(t *testing.T) test { + oldOids := []string{"foo", "bar"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + addOids := []string{"zap", "zar"} + bAddOids, err := json.Marshal(addOids) + assert.FatalError(t, err) + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + obar := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bbar, err := json.Marshal(obar) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + case "bar": + assert.Equals(t, key, []byte("bar")) + return bbar, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, bAddOids) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + addOids: addOids, + res: addOids, + } + }, + "ok/old-and-new-2": func(t *testing.T) test { + oldOids := []string{"foo", "bar", "baz"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + addOids := []string{"zap", "zar"} + now := clock.Now() + min5 := now.Add(5 * time.Minute) + expiry := now.Add(-5 * time.Minute) + + o1 := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: min5, + AuthorizationIDs: []string{"a"}, + } + bo1, err := json.Marshal(o1) + assert.FatalError(t, err) + o2 := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bo2, err := json.Marshal(o2) + assert.FatalError(t, err) + o3 := &dbOrder{ + ID: "baz", + Status: acme.StatusPending, + ExpiresAt: min5, + AuthorizationIDs: []string{"b"}, + } + bo3, err := json.Marshal(o3) + assert.FatalError(t, err) + + az1 := &dbAuthz{ + ID: "a", + Status: acme.StatusPending, + ExpiresAt: min5, + ChallengeIDs: []string{"aa"}, + } + baz1, err := json.Marshal(az1) + assert.FatalError(t, err) + az2 := &dbAuthz{ + ID: "b", + Status: acme.StatusPending, + ExpiresAt: min5, + ChallengeIDs: []string{"bb"}, + } + baz2, err := json.Marshal(az2) + assert.FatalError(t, err) + + ch1 := &dbChallenge{ + ID: "aa", + Status: acme.StatusPending, + } + bch1, err := json.Marshal(ch1) + assert.FatalError(t, err) + ch2 := &dbChallenge{ + ID: "bb", + Status: acme.StatusPending, + } + bch2, err := json.Marshal(ch2) + assert.FatalError(t, err) + + newOids := append([]string{"foo", "baz"}, addOids...) + bNewOids, err := json.Marshal(newOids) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + switch string(key) { + case "a": + return baz1, nil + case "b": + return baz2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", string(key))) + return nil, errors.New("force") + } + case string(challengeTable): + switch string(key) { + case "aa": + return bch1, nil + case "bb": + return bch2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected challenge key %s", string(key))) + return nil, errors.New("force") + } + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + return bo1, nil + case "bar": + return bo2, nil + case "baz": + return bo3, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, bNewOids) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + addOids: addOids, + res: newOids, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + var ( + res []string + err error + ) + if tc.addOids == nil { + res, err = db.updateAddOrderIDs(context.Background(), accID) + } else { + res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...) + } + + if err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.True(t, reflect.DeepEqual(res, tc.res)) + } + } + }) + } +} From 8b4a5a6d8b041bce7938a35ccfc3475d36a6e694 Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 23 Mar 2021 23:04:22 -0700 Subject: [PATCH 28/47] add unit tests for dns01 validate --- acme/challenge_test.go | 579 ++++++++++++++++++++++------------------- 1 file changed, 309 insertions(+), 270 deletions(-) diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 29bd5a71..9a637f17 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -4,10 +4,12 @@ import ( "bytes" "context" "crypto" + "crypto/sha256" "encoding/base64" "fmt" "io/ioutil" "net/http" + "strings" "testing" "time" @@ -21,7 +23,7 @@ func TestKeyAuthorization(t *testing.T) { token string jwk *jose.JSONWebKey exp string - err error + err *Error } tests := map[string]func(t *testing.T) test{ "fail/jwk-thumbprint-error": func(t *testing.T) test { @@ -31,7 +33,7 @@ func TestKeyAuthorization(t *testing.T) { return test{ token: "1234", jwk: jwk, - err: errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { @@ -53,7 +55,16 @@ func TestKeyAuthorization(t *testing.T) { tc := run(t) if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - assert.HasPrefix(t, err.Error(), tc.err.Error()) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { if assert.Nil(t, tc.err) { @@ -454,6 +465,301 @@ func TestHTTP01Validate(t *testing.T) { } } +func TestDNS01Validate(t *testing.T) { + fulldomain := "*.zap.internal" + domain := strings.TrimPrefix(fulldomain, "*.") + type test struct { + vo *ValidateChallengeOptions + ch *Challenge + jwk *jose.JSONWebKey + db DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/lookupTXT-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/lookupTXT-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo"}, nil + }, + }, + jwk: jwk, + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), + } + }, + "fail/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + jwk: jwk, + } + }, + "fail/update-challenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) + + return errors.New("force") + }, + }, + jwk: jwk, + err: NewErrorISE("error updating challenge: force"), + } + }, + "ok": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: fulldomain, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) + + return nil + }, + }, + jwk: jwk, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + /* func TestTLSALPN01Validate(t *testing.T) { type test struct { @@ -1170,273 +1476,6 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na }, nil } -func TestDNS01Validate(t *testing.T) { - type test struct { - vo validateOptions - ch challenge - res challenge - jwk *jose.JSONWebKey - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - return test{ - ch: ch, - res: ch, - } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid - return test{ - ch: ch, - res: ch, - } - }, - "ok/lookup-txt-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := DNSErr(errors.Errorf("error looking up TXT records for "+ - "domain %s: force", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &dns01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return nil, errors.New("force") - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/lookup-txt-wildcard": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Value = "*.zap.internal" - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &dns01Challenge{baseClone} - - return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - assert.Equals(t, url, "_acme-challenge.zap.internal") - return []string{"foo", expected}, nil - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - dnsCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, dnsCh.getStatus(), StatusValid) - baseClone.Validated = dnsCh.getValidated() - return nil, true, nil - }, - }, - } - }, - "fail/key-authorization-gen-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - jwk.Key = "foo" - return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", "bar"}, nil - }, - }, - jwk: jwk, - err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "ok/key-auth-mismatch": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", "bar"}, nil - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil - }, - }, - res: ch, - } - }, - "fail/save-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) - return test{ - ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", expected}, nil - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - } - }, - "ok": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &dns01Challenge{baseClone} - - return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", expected}, nil - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - - dnsCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, dnsCh.getStatus(), StatusValid) - assert.True(t, dnsCh.getValidated().Before(time.Now().UTC())) - assert.True(t, dnsCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) - - baseClone.Validated = dnsCh.getValidated() - - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } - } - }) - } -} - /* var testOps = ChallengeOptions{ AccountID: "accID", From 1fb0f1d7d9b812f6a279715877b768f1f4e81a1a Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 23 Mar 2021 23:16:17 -0700 Subject: [PATCH 29/47] add storeError unit tests --- acme/challenge_test.go | 112 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 9a637f17..29e20eaa 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -18,6 +18,118 @@ import ( "go.step.sm/crypto/jose" ) +func Test_storeError(t *testing.T) { + type test struct { + ch *Challenge + db DB + err *Error + } + err := NewError(ErrorMalformedType, "foo") + tests := map[string]func(t *testing.T) test{ + "fail/db.UpdateChallenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + return test{ + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "fail/db.UpdateChallenge-acme-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + return test{ + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return NewError(ErrorMalformedType, "bar") + }, + }, + err: NewError(ErrorMalformedType, "failure saving error to acme challenge: bar"), + } + }, + "ok": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + return test{ + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := storeError(context.Background(), tc.ch, tc.db, err); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + func TestKeyAuthorization(t *testing.T) { type test struct { token string From a8e4bbf715df11abb71cd67f1cd645ffb7050a8b Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 23 Mar 2021 23:30:59 -0700 Subject: [PATCH 30/47] start Validate unit tests --- acme/challenge_test.go | 211 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 29e20eaa..cfb990eb 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -187,6 +187,217 @@ func TestKeyAuthorization(t *testing.T) { } } +func TestChallenge_Validate(t *testing.T) { + type test struct { + ch *Challenge + vo *ValidateChallengeOptions + jwk *jose.JSONWebKey + db DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/already-valid": func(t *testing.T) test { + ch := &Challenge{ + Status: StatusValid, + } + return test{ + ch: ch, + } + }, + "fail/already-invalid": func(t *testing.T) test { + ch := &Challenge{ + Status: StatusInvalid, + } + return test{ + ch: ch, + } + }, + "fail/unexpected-type": func(t *testing.T) test { + ch := &Challenge{ + Status: StatusPending, + Type: "foo", + } + return test{ + ch: ch, + err: NewErrorISE("unexpected challenge type 'foo'"), + } + }, + "fail/http-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Status: StatusPending, + AuthzID: "azID", + Type: "http-01", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/http-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Status: StatusPending, + AuthzID: "azID", + Type: "http-01", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/dns-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Type: "dns-01", + Status: StatusPending, + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/dns-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Type: "dns-01", + Status: StatusPending, + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + type errReader int func (errReader) Read(p []byte) (n int, err error) { From a58466589f0ea56800e57c257e24df5e153b67ec Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 24 Mar 2021 14:32:18 -0700 Subject: [PATCH 31/47] add tls-alpn-01 validate unit tests --- acme/challenge_test.go | 1989 ++++++++++++++++++---------------------- 1 file changed, 896 insertions(+), 1093 deletions(-) diff --git a/acme/challenge_test.go b/acme/challenge_test.go index cfb990eb..d5b6cc58 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -4,11 +4,22 @@ import ( "bytes" "context" "crypto" + "crypto/rand" + "crypto/rsa" "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" "encoding/base64" + "encoding/hex" "fmt" + "io" "io/ioutil" + "math/big" + "net" "net/http" + "net/http/httptest" "strings" "testing" "time" @@ -193,6 +204,7 @@ func TestChallenge_Validate(t *testing.T) { vo *ValidateChallengeOptions jwk *jose.JSONWebKey db DB + srv *httptest.Server err *Error } tests := map[string]func(t *testing.T) test{ @@ -374,10 +386,97 @@ func TestChallenge_Validate(t *testing.T) { }, } }, + "fail/tls-alpn-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", + } + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/tls-alpn-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Error, nil) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) + + if tc.srv != nil { + defer tc.srv.Close() + } + if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { @@ -1083,1351 +1182,1055 @@ func TestDNS01Validate(t *testing.T) { } } -/* -func TestTLSALPN01Validate(t *testing.T) { - type test struct { - srv *httptest.Server - vo validateOptions - ch challenge - res challenge - jwk *jose.JSONWebKey - db nosql.DB - err *Error +func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { + srv := httptest.NewUnstartedServer(http.NewServeMux()) + + srv.Config.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){ + "acme-tls/1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + // no-op + }, + "http/1.1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + panic("unexpected http/1.1 next proto") + }, } - tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - return test{ - ch: ch, - res: ch, + srv.TLS = &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == "acme-tls/1" { + return validationCert, nil } + return nil, nil }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid + NextProtos: []string{ + "acme-tls/1", + "http/1.1", + }, + } - return test{ - ch: ch, - res: ch, - } + srv.Listener = tls.NewListener(srv.Listener, srv.TLS) + //srv.Config.ErrorLog = log.New(ioutil.Discard, "", 0) // hush + + return srv, func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + } +} + +// noopConn is a mock net.Conn that does nothing. +type noopConn struct{} + +func (c *noopConn) Read(_ []byte) (n int, err error) { return 0, io.EOF } +func (c *noopConn) Write(_ []byte) (n int, err error) { return 0, io.EOF } +func (c *noopConn) Close() error { return nil } +func (c *noopConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } +func (c *noopConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } +func (c *noopConn) SetDeadline(t time.Time) error { return nil } +func (c *noopConn) SetReadDeadline(t time.Time) error { return nil } +func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil } + +func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, names ...string) (*tls.Certificate, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + certTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1337), + Subject: pkix.Name{ + Organization: []string{"Test"}, }, - "ok/tls-dial-error": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, 1), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: names, + } - expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: force", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + if keyAuthHash != nil { + oid := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} + if obsoleteOID { + oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} + } + + keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) + certTemplate.ExtraExtensions = []pkix.Extension{ + { + Id: oid, + Critical: critical, + Value: keyAuthHashEnc, + }, + } + } + + cert, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, privateKey.Public(), privateKey) + if err != nil { + return nil, err + } + + return &tls.Certificate{ + PrivateKey: privateKey, + Certificate: [][]byte{cert}, + }, nil +} + +func TestTLSALPN01Validate(t *testing.T) { + makeTLSCh := func() *Challenge { + return &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", + } + } + type test struct { + vo *ValidateChallengeOptions + ch *Challenge + jwk *jose.JSONWebKey + db DB + srv *httptest.Server + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/tlsDial-store-error": func(t *testing.T) test { + ch := makeTLSCh() return test{ ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/timeout": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - srv, tlsDial := newTestTLSALPNServer(nil) - // srv.Start() - do not start server to cause timeout + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value) - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/no-certificates": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("tls-alpn-01 challenge for %v resulted in no certificates", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - + "ok/tlsDial-error": func(t *testing.T) test { + ch := makeTLSCh() return test{ ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.Client(&noopConn{}, config), nil + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return nil, errors.New("force") }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, } }, - "ok/no-names": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/tlsDial-timeout": func(t *testing.T) test { + ch := makeTLSCh() - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + srv, tlsDial := newTestTLSALPNServer(nil) + // srv.Start() - do not start server to cause timeout - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) - assert.FatalError(t, err) + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value) - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + } + }, + "ok/no-certificates-error": func(t *testing.T) test { + ch := makeTLSCh() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Client(&noopConn{}, config), nil + }, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, } }, - "ok/too-many-names": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "fail/no-certificates-store-error": func(t *testing.T) test { + ch := makeTLSCh() - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Client(&noopConn{}, config), nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-no-protocol": func(t *testing.T) test { + ch := makeTLSCh() - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue(), "other.internal") + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + srv := httptest.NewTLSServer(nil) return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + }, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, + srv: srv, + jwk: jwk, } }, - "ok/wrong-name": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + "fail/no-protocol-store-error": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + srv := httptest.NewTLSServer(nil) return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + }, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/no-extension": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/no-names-error": func(t *testing.T) test { + ch := makeTLSCh() - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(nil, false, true, ch.getValue()) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, + srv: srv, + jwk: jwk, } }, - "ok/extension-not-critical": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + "fail/no-names-store-error": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.getValue()) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/extension-malformed": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/too-many-names-error": func(t *testing.T) test { + ch := makeTLSCh() - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.getValue()) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value, "other.internal") assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, + srv: srv, + jwk: jwk, } }, - "ok/no-protocol": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/wrong-name": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.New("cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - srv := httptest.NewTLSServer(nil) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) - }, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, + srv: srv, + jwk: jwk, } }, - "ok/mismatched-token": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - incorrectTokenHash := sha256.Sum256([]byte("mismatched")) + jwk.Key = "foo" + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), + } + }, + "ok/error-no-extension": func(t *testing.T) test { + ch := makeTLSCh() - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "expected acmeValidationV1 extension value %s for this challenge but got %s", - hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:]))) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.getValue()) + cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, + srv: srv, + jwk: jwk, } }, - "ok/obsolete-oid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "fail/no-extension-store-error": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: " + - "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-extension-not-critical": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.getValue()) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, + srv: srv, + jwk: jwk, } }, - "ok": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &tlsALPN01Challenge{baseClone} + "fail/extension-not-critical-store-error": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue()) + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value) assert.FatalError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { - assert.Equals(t, network, "tcp") - assert.Equals(t, addr, net.JoinHostPort(newCh.getValue(), "443")) - assert.Equals(t, config.NextProtos, []string{"acme-tls/1"}) - assert.Equals(t, config.ServerName, newCh.getValue()) - assert.True(t, config.InsecureSkipVerify) + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical") - return tlsDial(network, addr, config) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, + srv: srv, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-malformed-extension": func(t *testing.T) test { + ch := makeTLSCh() - alpnCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, alpnCh.getStatus(), StatusValid) - assert.True(t, alpnCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, alpnCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - baseClone.Validated = alpnCh.getValidated() + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value") - return nil, true, nil + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: newCh, + srv: srv, + jwk: jwk, } }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) + "fail/malformed-extension-store-error": func(t *testing.T) test { + ch := makeTLSCh() - if tc.srv != nil { - defer tc.srv.Close() - } + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } - } - }) - } -} + cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value) + assert.FatalError(t, err) -func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { - srv := httptest.NewUnstartedServer(http.NewServeMux()) + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() - srv.Config.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){ - "acme-tls/1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { - // no-op - }, - "http/1.1": func(_ *http.Server, conn *tls.Conn, _ http.Handler) { - panic("unexpected http/1.1 next proto") - }, - } + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - srv.TLS = &tls.Config{ - GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == "acme-tls/1" { - return validationCert, nil + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } - return nil, nil - }, - NextProtos: []string{ - "acme-tls/1", - "http/1.1", }, - } + "ok/error-keyauth-mismatch": func(t *testing.T) test { + ch := makeTLSCh() - srv.Listener = tls.NewListener(srv.Listener, srv.TLS) - //srv.Config.ErrorLog = log.New(ioutil.Discard, "", 0) // hush + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) - return srv, func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { - return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) - } -} - -// noopConn is a mock net.Conn that does nothing. -type noopConn struct{} - -func (c *noopConn) Read(_ []byte) (n int, err error) { return 0, io.EOF } -func (c *noopConn) Write(_ []byte) (n int, err error) { return 0, io.EOF } -func (c *noopConn) Close() error { return nil } -func (c *noopConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } -func (c *noopConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} } -func (c *noopConn) SetDeadline(t time.Time) error { return nil } -func (c *noopConn) SetReadDeadline(t time.Time) error { return nil } -func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil } - -func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, names ...string) (*tls.Certificate, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - - certTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1337), - Subject: pkix.Name{ - Organization: []string{"Test"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - DNSNames: names, - } - - if keyAuthHash != nil { - oid := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} - if obsoleteOID { - oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} - } - - keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) - - certTemplate.ExtraExtensions = []pkix.Extension{ - { - Id: oid, - Critical: critical, - Value: keyAuthHashEnc, - }, - } - } + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + incorrectTokenHash := sha256.Sum256([]byte("mismatched")) - cert, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, privateKey.Public(), privateKey) - if err != nil { - return nil, err - } + cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value) + assert.FatalError(t, err) - return &tls.Certificate{ - PrivateKey: privateKey, - Certificate: [][]byte{cert}, - }, nil -} + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() -/* -var testOps = ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: "zap.internal", - }, -} + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) -func newDNSCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newDNS01Challenge(mockdb, testOps) -} + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])) -func newTLSALPNCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } }, - } - return newTLSALPN01Challenge(mockdb, testOps) -} + "fail/keyauth-mismatch-store-error": func(t *testing.T) test { + ch := makeTLSCh() -func newHTTPCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, testOps) -} + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) -func newHTTPChWithServer(host string) (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: host, - }, - }) -} + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + incorrectTokenHash := sha256.Sum256([]byte("mismatched")) -func TestNewHTTP01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newHTTP01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "http-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} + cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value) + assert.FatalError(t, err) -func TestNewTLSALPN01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newTLSALPN01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "tls-alpn-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() -func TestNewDNS01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "dns", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newDNS01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "dns-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) -func TestChallengeToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Validated = clock.Now() - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - tests := map[string]challenge{ - "dns": dnsCh, - "http": httpCh, - "tls-alpn": tlsALPNCh, - } - for name, ch := range tests { - t.Run(name, func(t *testing.T) { - ach, err := ch.toACME(ctx, nil, dir) - assert.FatalError(t, err) - - assert.Equals(t, ach.Type, ch.getType()) - assert.Equals(t, ach.Status, ch.getStatus()) - assert.Equals(t, ach.Token, ch.getToken()) - assert.Equals(t, ach.URL, - fmt.Sprintf("%s/acme/%s/challenge/%s", - baseURL.String(), provName, ch.getID())) - assert.Equals(t, ach.ID, ch.getID()) - assert.Equals(t, ach.AuthzID, ch.getAuthzID()) - - if ach.Type == "http-01" { - v, err := time.Parse(time.RFC3339, ach.Validated) - assert.FatalError(t, err) - assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) - } else { - assert.Equals(t, ach.Validated, "") - } - }) - } -} + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])) -func TestChallengeSave(t *testing.T) { - type test struct { - ch challenge - old challenge - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "fail/old-nil/swap-false": func(t *testing.T) test { - httpCh, err := newHTTPCh() + "ok/error-obsolete-oid": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - httpCh, err := newHTTPCh() + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - b, err := json.Marshal(httpCh) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value) assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil - }, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldHTTPCh, err := newHTTPCh() - assert.FatalError(t, err) - httpCh, err := newHTTPCh() - assert.FatalError(t, err) + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - oldb, err := json.Marshal(oldHTTPCh) - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: oldHTTPCh, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, + srv: srv, + jwk: jwk, } }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.ch.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestChallengeClone(t *testing.T) { - ch, err := newHTTPCh() - assert.FatalError(t, err) + "fail/obsolete-oid-store-error": func(t *testing.T) test { + ch := makeTLSCh() - clone := ch.clone() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) - assert.Equals(t, clone.getID(), ch.getID()) - assert.Equals(t, clone.getAccountID(), ch.getAccountID()) - assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, clone.getStatus(), ch.getStatus()) - assert.Equals(t, clone.getToken(), ch.getToken()) - assert.Equals(t, clone.getCreated(), ch.getCreated()) - assert.Equals(t, clone.getValidated(), ch.getValidated()) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - clone.Status = StatusValid + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value) + assert.FatalError(t, err) - assert.NotEquals(t, clone.getStatus(), ch.getStatus()) -} + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() -func TestChallengeUnmarshal(t *testing.T) { - type test struct { - ch challenge - chb []byte - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { return test{ - chb: nil, - err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), - } - }, - "fail/unexpected-type-http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Type = "foo" - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _tlsALPNCh.baseChallenge.Type = "foo" - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - chb: b, - } - }, - "ok/http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - "ok/alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - ch: tlsALPNCh, - chb: b, - } - }, - "ok/err": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := unmarshalChallenge(tc.chb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } - } - }) - } -} -func TestGetChallenge(t *testing.T) { - type test struct { - id string - db nosql.DB - ch challenge - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), - } - }, - "fail/db-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, - err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok": func(t *testing.T) test { - dnsCh, err := newDNSCh() + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Error, nil) + return nil }, }, + srv: srv, + jwk: jwk, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if ch, err := getChallenge(tc.db, tc.id); err != nil { + + if tc.srv != nil { + defer tc.srv.Close() + } + + if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } + assert.Nil(t, tc.err) } }) } } -*/ From c0a9f247989504566583d1391a19cdfae76ae07f Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 24 Mar 2021 16:50:35 -0700 Subject: [PATCH 32/47] add authorization and order unit tests --- acme/account_test.go | 759 ++-------------------------------- acme/authorization.go | 1 + acme/authorization_test.go | 150 +++++++ acme/authz_test.go | 824 ------------------------------------- acme/order.go | 9 +- acme/order_test.go | 256 ++++++++++++ 6 files changed, 451 insertions(+), 1548 deletions(-) create mode 100644 acme/authorization_test.go delete mode 100644 acme/authz_test.go diff --git a/acme/account_test.go b/acme/account_test.go index 45b86f20..5625c3dc 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -1,764 +1,81 @@ package acme import ( - "fmt" - "time" + "crypto" + "encoding/base64" + "testing" - "github.com/smallstep/certificates/authority/provisioner" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "go.step.sm/crypto/jose" ) -var ( - defaultDisableRenewal = false - globalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - } -) - -func newProv() Provisioner { - // Initialize provisioners - p := &provisioner.ACME{ - Type: "ACME", - Name: "test@acme-provisioner.com", - } - if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - fmt.Printf("%v", err) - } - return p -} - -/* -func newAcc() (*Account, error) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - if err != nil { - return nil, err - } - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - } - return newAccount(mockdb, AccountOptions{ - Key: jwk, Contact: []string{"foo", "bar"}, - }) -} -*/ - -/* -func TestGetAccountByID(t *testing.T) { - type test struct { - id string - db nosql.DB - acc *Account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: NewError(ErrorMalformedType, "account %s not found: not found", acc.ID), - } - }, - "fail/db-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: NewErrorISE("error loading account %s: force", acc.ID), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acc, err := getAccountByID(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, acc.ID) - assert.Equals(t, tc.acc.Status, acc.Status) - assert.Equals(t, tc.acc.Created, acc.Created) - assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) - assert.Equals(t, tc.acc.Contact, acc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) - } - } - }) - } -} - -func TestGetAccountByKeyID(t *testing.T) { - type test struct { - kid string - db nosql.DB - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/kid-not-found": func(t *testing.T) test { - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("account with key id foo not found: not found")), - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading key-account index: force")), - } - }, - "fail/getAccount-error": func(t *testing.T) test { - count := 0 - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte("foo")) - count++ - return []byte("bar"), nil - } - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading account bar: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - kid: acc.Key.KeyID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(acc.Key.KeyID)) - ret = []byte(acc.ID) - case 1: - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - ret = b - } - count++ - return ret, nil - }, - }, - acc: acc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, acc.ID) - assert.Equals(t, tc.acc.Status, acc.Status) - assert.Equals(t, tc.acc.Created, acc.Created) - assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) - assert.Equals(t, tc.acc.Contact, acc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) - } - } - }) - } -} - -func TestAccountToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{acc: acc} - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeAccount, err := tc.acc.toACME(ctx, nil, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeAccount.ID, tc.acc.ID) - assert.Equals(t, acmeAccount.Status, tc.acc.Status) - assert.Equals(t, acmeAccount.Contact, tc.acc.Contact) - assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID) - assert.Equals(t, acmeAccount.Orders, - fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID)) - } - } - }) - } -} - -func TestAccountSave(t *testing.T) { - type test struct { - acc, old *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing account; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, accountTable) - assert.Equals(t, []byte(acc.ID), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldAcc, err := newAcc() - assert.FatalError(t, err) - acc, err := newAcc() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldAcc) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - old: oldAcc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - assert.Equals(t, bucket, accountTable) - assert.Equals(t, []byte(acc.ID), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.acc.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAccountSaveNew(t *testing.T) { +func TestKeyToID(t *testing.T) { type test struct { - acc *account - db nosql.DB + jwk *jose.JSONWebKey + exp string err *Error } tests := map[string]func(t *testing.T) test{ - "fail/keyToID-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - acc.Key.Key = "foo" - return test{ - acc: acc, - err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "fail/swap-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "fail/swap-false": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.New("key-id to account-id index already exists")), - } - }, - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - count++ - return nil, true, nil - } - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, b) - return nil, false, errors.New("force") - }, - MDel: func(bucket, key []byte) error { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - return nil - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - count++ - return nil, true, nil - } - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, b) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.acc.saveNew(tc.db); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAccountUpdate(t *testing.T) { - type test struct { - acc *account - contact []string - db nosql.DB - res []byte - err *Error - } - contact := []string{"foo", "bar"} - tests := map[string]func(t *testing.T) test{ - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Contact = contact - b, err := json.Marshal(clone) + "fail/error-generating-thumbprint": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + jwk.Key = "foo" return test{ - acc: acc, - contact: contact, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), + jwk: jwk, + err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - _acc := *acc - clone := &_acc - clone.Contact = contact - b, err := json.Marshal(clone) - assert.FatalError(t, err) - return test{ - acc: acc, - contact: contact, - res: b, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - acc, err := tc.acc.update(tc.db, tc.contact) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - b, err := json.Marshal(acc) - assert.FatalError(t, err) - assert.Equals(t, b, tc.res) - } - } - }) - } -} - -func TestAccountDeactivate(t *testing.T) { - type test struct { - acc *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) + kid, err := jwk.Thumbprint(crypto.SHA256) assert.FatalError(t, err) return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - return nil, true, nil - }, - }, + jwk: jwk, + exp: base64.RawURLEncoding.EncodeToString(kid), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - acc, err := tc.acc.deactivate(tc.db) - if err != nil { + if id, err := KeyToID(tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, tc.acc.ID) - assert.Equals(t, acc.Contact, tc.acc.Contact) - assert.Equals(t, acc.Status, StatusDeactivated) - assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID) - assert.Equals(t, acc.Created, tc.acc.Created) - - assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute))) - assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute))) + assert.Equals(t, id, tc.exp) } } }) } } -func TestNewAccount(t *testing.T) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - kid, err := keyToID(jwk) - assert.FatalError(t, err) - ops := AccountOptions{ - Key: jwk, - Contact: []string{"foo", "bar"}, - } +func TestAccount_IsValid(t *testing.T) { type test struct { - ops AccountOptions - db nosql.DB - err *Error - id *string + acc *Account + exp bool } - tests := map[string]func(t *testing.T) test{ - "fail/store-error": func(t *testing.T) test { - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "ok": func(t *testing.T) test { - var _id string - id := &_id - count := 0 - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - case 1: - assert.Equals(t, bucket, accountTable) - *id = string(key) - } - count++ - return nil, true, nil - }, - }, - id: id, - } - }, + tests := map[string]test{ + "valid": {acc: &Account{Status: StatusValid}, exp: true}, + "invalid": {acc: &Account{Status: StatusInvalid}, exp: false}, } - for name, run := range tests { - tc := run(t) + for name, tc := range tests { t.Run(name, func(t *testing.T) { - acc, err := newAccount(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, *tc.id) - assert.Equals(t, acc.Status, StatusValid) - assert.Equals(t, acc.Contact, ops.Contact) - assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID) - - assert.True(t, acc.Deactivated.IsZero()) - - assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute))) - } - } + assert.Equals(t, tc.acc.IsValid(), tc.exp) }) } } -*/ diff --git a/acme/authorization.go b/acme/authorization.go index 62bc4637..4d5c42c8 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -57,6 +57,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { return nil } az.Status = StatusValid + az.Error = nil default: return NewErrorISE("unrecognized authorization status: %s", az.Status) } diff --git a/acme/authorization_test.go b/acme/authorization_test.go new file mode 100644 index 00000000..00b35b99 --- /dev/null +++ b/acme/authorization_test.go @@ -0,0 +1,150 @@ +package acme + +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" +) + +func TestAuthorization_UpdateStatus(t *testing.T) { + type test struct { + az *Authorization + err *Error + db DB + } + tests := map[string]func(t *testing.T) test{ + "ok/already-invalid": func(t *testing.T) test { + az := &Authorization{ + Status: StatusInvalid, + } + return test{ + az: az, + } + }, + "ok/already-valid": func(t *testing.T) test { + az := &Authorization{ + Status: StatusInvalid, + } + return test{ + az: az, + } + }, + "fail/error-unexpected-status": func(t *testing.T) test { + az := &Authorization{ + Status: "foo", + } + return test{ + az: az, + err: NewErrorISE("unrecognized authorization status: %s", az.Status), + } + }, + "ok/expired": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusInvalid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + return nil + }, + }, + } + }, + "fail/db.UpdateAuthorization-error": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusInvalid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating authorization: force"), + } + }, + "ok/no-valid-challenges": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*Challenge{ + {Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending}, + }, + } + return test{ + az: az, + } + }, + "ok/valid": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*Challenge{ + {Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid}, + }, + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusValid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + assert.Equals(t, updaz.Error, nil) + return nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + + } +} diff --git a/acme/authz_test.go b/acme/authz_test.go deleted file mode 100644 index 206921c6..00000000 --- a/acme/authz_test.go +++ /dev/null @@ -1,824 +0,0 @@ -package acme - -/* -func newAz() (*Authorization, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newAuthz(mockdb, "1234", Identifier{ - Type: "dns", Value: "acme.example.com", - }) -} - -func TestGetAuthz(t *testing.T) { - type test struct { - id string - db nosql.DB - az authz - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())), - } - }, - "fail/db-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Identifier.Type = "foo" - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, - err: ServerInternalErr(errors.New("unexpected authz type foo")), - } - }, - "ok": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if az, err := getAuthz(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.az.getID(), az.getID()) - assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) - assert.Equals(t, tc.az.getStatus(), az.getStatus()) - assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier()) - assert.Equals(t, tc.az.getCreated(), az.getCreated()) - assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) - assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) - } - } - }) - } -} - -func TestAuthzClone(t *testing.T) { - az, err := newAz() - assert.FatalError(t, err) - - clone := az.clone() - - assert.Equals(t, clone.getID(), az.getID()) - assert.Equals(t, clone.getAccountID(), az.getAccountID()) - assert.Equals(t, clone.getStatus(), az.getStatus()) - assert.Equals(t, clone.getIdentifier(), az.getIdentifier()) - assert.Equals(t, clone.getExpiry(), az.getExpiry()) - assert.Equals(t, clone.getCreated(), az.getCreated()) - assert.Equals(t, clone.getChallenges(), az.getChallenges()) - - clone.Status = StatusValid - - assert.NotEquals(t, clone.getStatus(), az.getStatus()) -} - -func TestNewAuthz(t *testing.T) { - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - accID := "1234" - type test struct { - iden Identifier - db nosql.DB - err *Error - resChs *([]string) - } - tests := map[string]func(t *testing.T) test{ - "fail/unexpected-type": func(t *testing.T) test { - return test{ - iden: Identifier{Type: "foo", Value: "acme.example.com"}, - err: MalformedErr(errors.New("unexpected authz type foo")), - } - }, - "fail/new-http-chall-error": func(t *testing.T) test { - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")), - } - }, - "fail/new-tls-alpn-chall-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")), - } - }, - "fail/new-dns-chall-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 2 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")), - } - }, - "fail/save-authz-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 3 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "ok": func(t *testing.T) test { - chs := &([]string{}) - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 3 { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, old, nil) - - az, err := unmarshalAuthz(newval) - assert.FatalError(t, err) - - assert.Equals(t, az.getID(), string(key)) - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getStatus(), StatusPending) - assert.Equals(t, az.getIdentifier(), iden) - assert.Equals(t, az.getWildcard(), false) - - *chs = az.getChallenges() - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - } - count++ - return nil, true, nil - }, - }, - resChs: chs, - } - }, - "ok/wildcard": func(t *testing.T) test { - chs := &([]string{}) - count := 0 - _iden := Identifier{Type: "dns", Value: "*.acme.example.com"} - return test{ - iden: _iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, old, nil) - - az, err := unmarshalAuthz(newval) - assert.FatalError(t, err) - - assert.Equals(t, az.getID(), string(key)) - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getStatus(), StatusPending) - assert.Equals(t, az.getIdentifier(), iden) - assert.Equals(t, az.getWildcard(), true) - - *chs = az.getChallenges() - // Verify that we only have 1 challenge instead of 2. - assert.True(t, len(*chs) == 1) - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - } - count++ - return nil, true, nil - }, - }, - resChs: chs, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - az, err := newAuthz(tc.db, accID, tc.iden) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getType(), "dns") - assert.Equals(t, az.getStatus(), StatusPending) - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - - assert.Equals(t, az.getChallenges(), *(tc.resChs)) - - if strings.HasPrefix(tc.iden.Value, "*.") { - assert.True(t, az.getWildcard()) - assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*.")) - } else { - assert.False(t, az.getWildcard()) - assert.Equals(t, az.getIdentifier().Value, tc.iden.Value) - } - - assert.True(t, az.getID() != "") - } - } - }) - } -} - -func TestAuthzToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - var ( - ch1, ch2 challenge - ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) - err error - ) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - ch1, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } else if count == 1 { - *ch2Bytes = newval - ch2, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } - count++ - return []byte("foo"), true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - - type test struct { - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getChallenge1-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "fail/getChallenge2-error": func(t *testing.T) test { - count := 0 - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 1 { - return nil, errors.New("force") - } - count++ - return *ch1Bytes, nil - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "ok": func(t *testing.T) test { - count := 0 - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - return *ch2Bytes, nil - }, - }, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeAz, err := az.toACME(ctx, tc.db, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeAz.ID, az.getID()) - assert.Equals(t, acmeAz.Identifier, iden) - assert.Equals(t, acmeAz.Status, StatusPending) - - acmeCh1, err := ch1.toACME(ctx, nil, dir) - assert.FatalError(t, err) - acmeCh2, err := ch2.toACME(ctx, nil, dir) - assert.FatalError(t, err) - - assert.Equals(t, acmeAz.Challenges[0], acmeCh1) - assert.Equals(t, acmeAz.Challenges[1], acmeCh2) - - expiry, err := time.Parse(time.RFC3339, acmeAz.Expires) - assert.FatalError(t, err) - assert.Equals(t, expiry.String(), az.getExpiry().String()) - } - } - }) - } -} - -func TestAuthzSave(t *testing.T) { - type test struct { - az, old authz - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, authzTable) - assert.Equals(t, []byte(az.getID()), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldAz, err := newAz() - assert.FatalError(t, err) - az, err := newAz() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldAz) - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - old: oldAz, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, authzTable) - assert.Equals(t, []byte(az.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.az.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAuthzUnmarshal(t *testing.T) { - type test struct { - az authz - azb []byte - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { - return test{ - azb: nil, - err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")), - } - }, - "fail/unexpected-type": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Identifier.Type = "foo" - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - azb: b, - err: ServerInternalErr(errors.New("unexpected authz type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - azb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if az, err := unmarshalAuthz(tc.azb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.az.getID(), az.getID()) - assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) - assert.Equals(t, tc.az.getStatus(), az.getStatus()) - assert.Equals(t, tc.az.getCreated(), az.getCreated()) - assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) - assert.Equals(t, tc.az.getWildcard(), az.getWildcard()) - assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) - } - } - }) - } -} - -func TestAuthzUpdateStatus(t *testing.T) { - type test struct { - az, res authz - err *Error - db nosql.DB - } - tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusInvalid - return test{ - az: az, - res: az, - } - }, - "fail/already-valid": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusValid - return test{ - az: az, - res: az, - } - }, - "fail/unexpected-status": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusReady - return test{ - az: az, - res: az, - err: ServerInternalErr(errors.New("unrecognized authz status: ready")), - } - }, - "fail/save-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "ok/expired": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) - - clone := az.clone() - clone.Error = MalformedErr(errors.New("authz has expired")) - clone.Status = StatusInvalid - return test{ - az: az, - res: clone.parent(), - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "ok/valid": func(t *testing.T) test { - var ( - ch3 challenge - ch2Bytes = &([]byte{}) - ch1Bytes = &([]byte{}) - err error - ) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - } else if count == 1 { - *ch2Bytes = newval - } else if count == 2 { - ch3, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } - count++ - return nil, true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Error = MalformedErr(nil) - - _ch, ok := ch3.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - chb, err := json.Marshal(ch3) - - clone := az.clone() - clone.Status = StatusValid - clone.Error = nil - - count = 0 - return test{ - az: az, - res: clone.parent(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - if count == 1 { - count++ - return *ch2Bytes, nil - } - count++ - return chb, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "ok/still-pending": func(t *testing.T) test { - var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - } else if count == 1 { - *ch2Bytes = newval - } - count++ - return nil, true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - - count = 0 - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - count++ - return *ch2Bytes, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - az, err := tc.az.updateStatus(tc.db) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } - } - }) - } -} -*/ diff --git a/acme/order.go b/acme/order.go index f62e3354..400a4ce2 100644 --- a/acme/order.go +++ b/acme/order.go @@ -81,10 +81,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { for _, azID := range o.AuthorizationIDs { az, err := db.GetAuthorization(ctx, azID) if err != nil { - return err + return WrapErrorISE(err, "error getting authorization ID %s", azID) } if err = az.UpdateStatus(ctx, db); err != nil { - return err + return WrapErrorISE(err, "error updating authorization ID %s", azID) } st := az.Status count[st]++ @@ -107,7 +107,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { default: return NewErrorISE("unrecognized order status: %s", o.Status) } - return db.UpdateOrder(ctx, o) + if err := db.UpdateOrder(ctx, o); err != nil { + return WrapErrorISE(err, "error updating order") + } + return nil } // Finalize signs a certificate if the necessary conditions for Order completion diff --git a/acme/order_test.go b/acme/order_test.go index 5bd21fdb..d86afeb5 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -1,5 +1,261 @@ package acme +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" +) + +func TestOrder_UpdateStatus(t *testing.T) { + type test struct { + o *Order + err *Error + db DB + } + tests := map[string]func(t *testing.T) test{ + "ok/already-invalid": func(t *testing.T) test { + o := &Order{ + Status: StatusInvalid, + } + return test{ + o: o, + } + }, + "ok/already-valid": func(t *testing.T) test { + o := &Order{ + Status: StatusInvalid, + } + return test{ + o: o, + } + }, + "fail/error-unexpected-status": func(t *testing.T) test { + o := &Order{ + Status: "foo", + } + return test{ + o: o, + err: NewErrorISE("unrecognized order status: %s", o.Status), + } + }, + "ok/ready-expired": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + }, + } + }, + "fail/ready-expired-db.UpdateOrder-error": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating order: force"), + } + }, + "ok/pending-expired": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + + err := NewError(ErrorMalformedType, "order has expired") + assert.HasPrefix(t, updo.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updo.Error.Type, err.Type) + assert.Equals(t, updo.Error.Detail, err.Detail) + assert.Equals(t, updo.Error.Status, err.Status) + assert.Equals(t, updo.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "ok/invalid": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusInvalid, + } + + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + "ok/still-pending": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusPending, + } + + return test{ + o: o, + db: &MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + "ok/valid": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusValid, + } + + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusReady) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + + } +} + /* var certDuration = 6 * time.Hour From bdf4c0f836951e9493103afa710f7889593604fe Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 24 Mar 2021 20:07:21 -0700 Subject: [PATCH 33/47] add acme order unit tests --- acme/order.go | 7 +- acme/order_test.go | 1861 ++++++++------------------------------------ 2 files changed, 317 insertions(+), 1551 deletions(-) diff --git a/acme/order.go b/acme/order.go index 400a4ce2..fdc250ea 100644 --- a/acme/order.go +++ b/acme/order.go @@ -204,12 +204,15 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques Intermediates: certChain[1:], } if err := db.CreateCertificate(ctx, cert); err != nil { - return err + return WrapErrorISE(err, "error creating certificate for order %s", o.ID) } o.CertificateID = cert.ID o.Status = StatusValid - return db.UpdateOrder(ctx, o) + if err = db.UpdateOrder(ctx, o); err != nil { + return WrapErrorISE(err, "error updating order %s", o.ID) + } + return nil } // uniqueSortedLowerNames returns the set of all unique names in the input after all diff --git a/acme/order_test.go b/acme/order_test.go index d86afeb5..993a92f2 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -2,11 +2,15 @@ package acme import ( "context" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" ) func TestOrder_UpdateStatus(t *testing.T) { @@ -256,882 +260,6 @@ func TestOrder_UpdateStatus(t *testing.T) { } } -/* -var certDuration = 6 * time.Hour - -func defaultOrderOps() OrderOptions { - return OrderOptions{ - AccountID: "accID", - Identifiers: []Identifier{ - {Type: "dns", Value: "acme.example.com"}, - {Type: "dns", Value: "step.example.com"}, - }, - NotBefore: clock.Now(), - NotAfter: clock.Now().Add(certDuration), - } -} - -func newO() (*order, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - } - return newOrder(mockdb, defaultOrderOps()) -} - -func Test_getOrder(t *testing.T) { - type test struct { - id string - db nosql.DB - o *order - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("order %s not found: not found", o.ID)), - } - }, - "fail/db-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading order %s: force", o.ID)), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling order: unexpected end of JSON input")), - } - }, - "ok": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if o, err := getOrder(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.o.ID, o.ID) - assert.Equals(t, tc.o.AccountID, o.AccountID) - assert.Equals(t, tc.o.Status, o.Status) - assert.Equals(t, tc.o.Identifiers, o.Identifiers) - assert.Equals(t, tc.o.Created, o.Created) - assert.Equals(t, tc.o.Expires, o.Expires) - assert.Equals(t, tc.o.Authorizations, o.Authorizations) - assert.Equals(t, tc.o.NotBefore, o.NotBefore) - assert.Equals(t, tc.o.NotAfter, o.NotAfter) - assert.Equals(t, tc.o.Certificate, o.Certificate) - assert.Equals(t, tc.o.Error, o.Error) - } - } - }) - } -} - -func TestOrderToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - o *order - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/no-cert": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{o: o} - }, - "ok/cert": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - o.Certificate = "cert-id" - return test{o: o} - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeOrder, err := tc.o.toACME(ctx, nil, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeOrder.ID, tc.o.ID) - assert.Equals(t, acmeOrder.Status, tc.o.Status) - assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers) - assert.Equals(t, acmeOrder.Finalize, - fmt.Sprintf("%s/acme/%s/order/%s/finalize", baseURL.String(), provName, tc.o.ID)) - if tc.o.Certificate != "" { - assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, tc.o.Certificate)) - } - - expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires) - assert.FatalError(t, err) - assert.Equals(t, expiry.String(), tc.o.Expires.String()) - nbf, err := time.Parse(time.RFC3339, acmeOrder.NotBefore) - assert.FatalError(t, err) - assert.Equals(t, nbf.String(), tc.o.NotBefore.String()) - naf, err := time.Parse(time.RFC3339, acmeOrder.NotAfter) - assert.FatalError(t, err) - assert.Equals(t, naf.String(), tc.o.NotAfter.String()) - } - } - }) - } -} - -func TestOrderSave(t *testing.T) { - type test struct { - o, old *order - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing order: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing order; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - return test{ - o: o, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, orderTable) - assert.Equals(t, []byte(o.ID), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldo, err := newO() - assert.FatalError(t, err) - o, err := newO() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldo) - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - return test{ - o: o, - old: oldo, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, orderTable) - assert.Equals(t, []byte(o.ID), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.o.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func Test_newOrder(t *testing.T) { - type test struct { - ops OrderOptions - db nosql.DB - err *Error - authzs *([]string) - } - tests := map[string]func(t *testing.T) test{ - "fail/unexpected-identifier-type": func(t *testing.T) test { - ops := defaultOrderOps() - ops.Identifiers[0].Type = "foo" - return test{ - ops: ops, - err: MalformedErr(errors.New("unexpected authz type foo")), - } - }, - "fail/save-order-error": func(t *testing.T) test { - count := 0 - return test{ - ops: defaultOrderOps(), - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 8 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error storing order: force")), - } - }, - "fail/get-orderIDs-error": func(t *testing.T) test { - count := 0 - ops := defaultOrderOps() - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading orderIDs for account %s: force", ops.AccountID)), - } - }, - "fail/save-orderIDs-error": func(t *testing.T) test { - count := 0 - var ( - _oid = "" - oid = &_oid - ) - ops := defaultOrderOps() - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(ops.AccountID)) - return nil, false, errors.New("force") - } else if count == 8 { - *oid = string(key) - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - MDel: func(bucket, key []byte) error { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(*oid)) - return nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", ops.AccountID)), - } - }, - "ok": func(t *testing.T) test { - count := 0 - authzs := &([]string{}) - var ( - _oid = "" - oid = &_oid - ) - ops := defaultOrderOps() - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(ops.AccountID)) - assert.Equals(t, old, nil) - newB, err := json.Marshal([]string{*oid}) - assert.FatalError(t, err) - assert.Equals(t, newval, newB) - } else if count == 8 { - *oid = string(key) - } else if count == 7 { - *authzs = append(*authzs, string(key)) - } else if count == 3 { - *authzs = []string{string(key)} - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - authzs: authzs, - } - }, - "ok/validity-bounds-not-set": func(t *testing.T) test { - count := 0 - authzs := &([]string{}) - var ( - _oid = "" - oid = &_oid - ) - ops := defaultOrderOps() - ops.backdate = time.Minute - ops.defaultDuration = 12 * time.Hour - ops.NotBefore = time.Time{} - ops.NotAfter = time.Time{} - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(ops.AccountID)) - assert.Equals(t, old, nil) - newB, err := json.Marshal([]string{*oid}) - assert.FatalError(t, err) - assert.Equals(t, newval, newB) - } else if count == 8 { - *oid = string(key) - } else if count == 7 { - *authzs = append(*authzs, string(key)) - } else if count == 3 { - *authzs = []string{string(key)} - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - authzs: authzs, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - o, err := newOrder(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, o.AccountID, tc.ops.AccountID) - assert.Equals(t, o.Status, StatusPending) - assert.Equals(t, o.Identifiers, tc.ops.Identifiers) - assert.Equals(t, o.Error, nil) - assert.Equals(t, o.Certificate, "") - assert.Equals(t, o.Authorizations, *tc.authzs) - - assert.True(t, o.Created.Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, o.Created.After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := o.Created.Add(defaultExpiryDuration) - assert.True(t, o.Expires.Before(expiry.Add(time.Minute))) - assert.True(t, o.Expires.After(expiry.Add(-1*time.Minute))) - - nbf := tc.ops.NotBefore - now := time.Now().UTC() - if !tc.ops.NotBefore.IsZero() { - assert.Equals(t, o.NotBefore, tc.ops.NotBefore) - } else { - nbf = o.NotBefore.Add(tc.ops.backdate) - assert.True(t, o.NotBefore.Before(now.Add(-tc.ops.backdate+time.Second))) - assert.True(t, o.NotBefore.Add(tc.ops.backdate+2*time.Second).After(now)) - } - if !tc.ops.NotAfter.IsZero() { - assert.Equals(t, o.NotAfter, tc.ops.NotAfter) - } else { - naf := nbf.Add(tc.ops.defaultDuration) - assert.Equals(t, o.NotAfter, naf) - } - } - } - }) - } -} - -func TestOrderIDs_save(t *testing.T) { - accID := "acc-id" - newOids := func() orderIDs { - return []string{"1", "2"} - } - type test struct { - oids, old orderIDs - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - return test{ - oids: newOids(), - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", accID)), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - return test{ - oids: newOids(), - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s; order IDs changed since last read", accID)), - } - }, - "ok/old-nil": func(t *testing.T) test { - oids := newOids() - b, err := json.Marshal(oids) - assert.FatalError(t, err) - return test{ - oids: oids, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldOids := newOids() - oids := append(oldOids, "3") - - oldb, err := json.Marshal(oldOids) - assert.FatalError(t, err) - b, err := json.Marshal(oids) - assert.FatalError(t, err) - return test{ - oids: oids, - old: oldOids, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - return nil, true, nil - }, - }, - } - }, - "ok/new-empty-saved-as-nil": func(t *testing.T) test { - oldOids := newOids() - oids := []string{} - - oldb, err := json.Marshal(oldOids) - assert.FatalError(t, err) - return test{ - oids: oids, - old: oldOids, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, nil) - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.oids.save(tc.db, tc.old, accID); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestOrderUpdateStatus(t *testing.T) { - type test struct { - o, res *order - err *Error - db nosql.DB - } - tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusInvalid - return test{ - o: o, - res: o, - } - }, - "fail/already-valid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - return test{ - o: o, - res: o, - } - }, - "fail/unexpected-status": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusDeactivated - return test{ - o: o, - res: o, - err: ServerInternalErr(errors.New("unrecognized order status: deactivated")), - } - }, - "fail/save-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Expires = time.Now().UTC().Add(-time.Minute) - return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing order: force")), - } - }, - "ok/expired": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Expires = time.Now().UTC().Add(-time.Minute) - - _o := *o - clone := &_o - clone.Error = MalformedErr(errors.New("order has expired")) - clone.Status = StatusInvalid - return test{ - o: o, - res: clone, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "fail/get-authz-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading authz")), - } - }, - "ok/still-pending": func(t *testing.T) test { - az1, err := newAz() - assert.FatalError(t, err) - az2, err := newAz() - assert.FatalError(t, err) - az3, err := newAz() - assert.FatalError(t, err) - - ch1, err := newHTTPCh() - assert.FatalError(t, err) - ch2, err := newTLSALPNCh() - assert.FatalError(t, err) - ch3, err := newDNSCh() - assert.FatalError(t, err) - - ch1b, err := json.Marshal(ch1) - assert.FatalError(t, err) - ch2b, err := json.Marshal(ch2) - assert.FatalError(t, err) - ch3b, err := json.Marshal(ch3) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()} - - _az3, ok := az3.(*dnsAuthz) - assert.Fatal(t, ok) - _az3.baseAuthz.Status = StatusValid - - b1, err := json.Marshal(az1) - assert.FatalError(t, err) - b2, err := json.Marshal(az2) - assert.FatalError(t, err) - b3, err := json.Marshal(az3) - assert.FatalError(t, err) - - count := 0 - return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - ret = b1 - case 1: - ret = ch1b - case 2: - ret = ch2b - case 3: - ret = ch3b - case 4: - ret = b2 - case 5: - ret = ch1b - case 6: - ret = ch2b - case 7: - ret = ch3b - case 8: - ret = b3 - default: - return nil, errors.New("unexpected count") - } - count++ - return ret, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "ok/invalid": func(t *testing.T) test { - az1, err := newAz() - assert.FatalError(t, err) - az2, err := newAz() - assert.FatalError(t, err) - az3, err := newAz() - assert.FatalError(t, err) - - ch1, err := newHTTPCh() - assert.FatalError(t, err) - ch2, err := newTLSALPNCh() - assert.FatalError(t, err) - ch3, err := newDNSCh() - assert.FatalError(t, err) - - ch1b, err := json.Marshal(ch1) - assert.FatalError(t, err) - ch2b, err := json.Marshal(ch2) - assert.FatalError(t, err) - ch3b, err := json.Marshal(ch3) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()} - - _az3, ok := az3.(*dnsAuthz) - assert.Fatal(t, ok) - _az3.baseAuthz.Status = StatusInvalid - - b1, err := json.Marshal(az1) - assert.FatalError(t, err) - b2, err := json.Marshal(az2) - assert.FatalError(t, err) - b3, err := json.Marshal(az3) - assert.FatalError(t, err) - - _o := *o - clone := &_o - clone.Status = StatusInvalid - - count := 0 - return test{ - o: o, - res: clone, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - ret = b1 - case 1: - ret = ch1b - case 2: - ret = ch2b - case 3: - ret = ch3b - case 4: - ret = b2 - case 5: - ret = ch1b - case 6: - ret = ch2b - case 7: - ret = ch3b - case 8: - ret = b3 - default: - return nil, errors.New("unexpected count") - } - count++ - return ret, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - o, err := tc.o.updateStatus(tc.db) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } - } - }) - } -} - type mockSignAuth struct { sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) loadProvisionerByID func(string) (provisioner.Interface, error) @@ -1155,822 +283,457 @@ func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, er return m.ret1.(provisioner.Interface), m.err } -func TestOrderFinalize(t *testing.T) { - prov := newProv() +func TestOrder_Finalize(t *testing.T) { type test struct { - o, res *order - err *Error - db nosql.DB - csr *x509.CertificateRequest - sa SignAuthority - prov Provisioner + o *Order + err *Error + db DB + ca CertificateAuthority + csr *x509.CertificateRequest + prov Provisioner } tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusInvalid - return test{ - o: o, - err: OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)), + "fail/invalid": func(t *testing.T) test { + o := &Order{ + ID: "oid", + Status: StatusInvalid, } - }, - "ok/already-valid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - o.Certificate = "cert-id" return test{ o: o, - res: o, + err: NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID), } }, - "fail/still-pending": func(t *testing.T) test { - az1, err := newAz() - assert.FatalError(t, err) - az2, err := newAz() - assert.FatalError(t, err) - az3, err := newAz() - assert.FatalError(t, err) - - ch1, err := newHTTPCh() - assert.FatalError(t, err) - ch2, err := newTLSALPNCh() - assert.FatalError(t, err) - ch3, err := newDNSCh() - assert.FatalError(t, err) - - ch1b, err := json.Marshal(ch1) - assert.FatalError(t, err) - ch2b, err := json.Marshal(ch2) - assert.FatalError(t, err) - ch3b, err := json.Marshal(ch3) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()} - - _az3, ok := az3.(*dnsAuthz) - assert.Fatal(t, ok) - _az3.baseAuthz.Status = StatusValid - - b1, err := json.Marshal(az1) - assert.FatalError(t, err) - b2, err := json.Marshal(az2) - assert.FatalError(t, err) - b3, err := json.Marshal(az3) - assert.FatalError(t, err) + "fail/pending": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + } - count := 0 return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - ret = b1 - case 1: - ret = ch1b - case 2: - ret = ch2b - case 3: - ret = ch3b - case 4: - ret = b2 - case 5: - ret = ch1b - case 6: - ret = ch2b - case 7: - ret = ch3b - case 8: - ret = b3 + o: o, + db: &MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil default: - return nil, errors.New("unexpected count") + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") } - count++ - return ret, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil }, }, - err: OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)), - } - }, - "fail/ready/csr-names-match-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"acme.example.com", "fail.smallstep.com"}, - } - return test{ - o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), - } - }, - "fail/ready/csr-names-match-error-2": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "", - }, - DNSNames: []string{"acme.example.com"}, - } - return test{ - o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + err: NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID), } }, - "fail/ready/no-ipAddresses": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "", - }, - // DNSNames: []string{"acme.example.com", "step.example.com"}, - IPAddresses: []net.IP{net.ParseIP("1.1.1.1")}, + "ok/already-valid": func(t *testing.T) test { + o := &Order{ + ID: "oid", + Status: StatusValid, } return test{ - o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + o: o, } }, - "fail/ready/no-emailAddresses": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "", - }, - // DNSNames: []string{"acme.example.com", "step.example.com"}, - EmailAddresses: []string{"max@smallstep.com", "mariano@smallstep.com"}, - } - return test{ - o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + "fail/error-unexpected-status": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: "foo", + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, } - }, - "fail/ready/no-URIs": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - u, err := url.Parse("https://google.com") - assert.FatalError(t, err) - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "", - }, - // DNSNames: []string{"acme.example.com", "step.example.com"}, - URIs: []*url.URL{u}, - } return test{ o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + err: NewErrorISE("unrecognized order status: %s", o.Status), } }, - "fail/ready/provisioner-auth-sign-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - + "fail/error-names-length-mismatch": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } + orderNames := []string{"bar.internal", "foo.internal"} csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "acme.example.com", + CommonName: "foo.internal", }, - DNSNames: []string{"step.example.com", "acme.example.com"}, } + return test{ o: o, csr: csr, - err: ServerInternalErr(errors.New("error retrieving authorization options from ACME provisioner: force")), - prov: &MockProvisioner{ - MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { - return nil, errors.New("force") - }, - }, + err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", []string{"foo.internal"}, orderNames), } }, - "fail/ready/sign-cert-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - + "fail/error-names-mismatch": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } + orderNames := []string{"bar.internal", "foo.internal"} csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "acme.example.com", + CommonName: "foo.internal", }, - DNSNames: []string{"step.example.com", "acme.example.com"}, + DNSNames: []string{"zap.internal"}, } + return test{ o: o, csr: csr, - err: ServerInternalErr(errors.Errorf("error generating certificate for order %s: force", o.ID)), - sa: &mockSignAuth{ - err: errors.New("force"), - }, + err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, orderNames), } }, - "fail/ready/store-cert-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"step.example.com", "acme.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", + "fail/error-provisioner-auth": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, } - inter := &x509.Certificate{ + csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "intermediate", + CommonName: "foo.internal", }, + DNSNames: []string{"bar.internal"}, } + return test{ o: o, csr: csr, - err: ServerInternalErr(errors.Errorf("error storing certificate: force")), - sa: &mockSignAuth{ - ret1: crt, ret2: inter, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, errors.New("force") }, }, + err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"), } }, - "fail/ready/store-order-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"acme.example.com", "step.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", + "fail/error-template-options": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, } - inter := &x509.Certificate{ + csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "intermediate", + CommonName: "foo.internal", }, + DNSNames: []string{"bar.internal"}, } - count := 0 + return test{ o: o, csr: csr, - err: ServerInternalErr(errors.Errorf("error storing order: force")), - sa: &mockSignAuth{ - ret1: crt, ret2: inter, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - return nil, false, errors.New("force") + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil + }, + MgetOptions: func() *provisioner.Options { + return &provisioner.Options{ + X509: &provisioner.X509Options{ + TemplateData: json.RawMessage([]byte("fo{o")), + }, } - count++ - return nil, true, nil }, }, + err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"), } }, - "ok/ready/sign": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"acme.example.com", "step.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", + "fail/error-ca-sign": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, } - inter := &x509.Certificate{ + csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "intermediate", + CommonName: "foo.internal", }, + DNSNames: []string{"bar.internal"}, } - _o := *o - clone := &_o - clone.Status = StatusValid - - count := 0 return test{ o: o, - res: clone, csr: csr, - sa: &mockSignAuth{ - sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) { - assert.Equals(t, len(signOps), 6) - return []*x509.Certificate{crt, inter}, nil + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil + }, + MgetOptions: func() *provisioner.Options { + return nil }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - clone.Certificate = string(key) - } - count++ - return nil, true, nil + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return nil, errors.New("force") }, }, + err: NewErrorISE("error signing certificate for order oID: force"), } }, - "ok/ready/no-sans": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - o.Identifiers = []Identifier{ - {Type: "dns", Value: "step.example.com"}, - } - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "step.example.com", - }, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "step.example.com", + "fail/error-db.CreateCertificate": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, - DNSNames: []string{"step.example.com"}, } - inter := &x509.Certificate{ + csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "intermediate", + CommonName: "foo.internal", }, + DNSNames: []string{"bar.internal"}, } - clone := *o - clone.Status = StatusValid - count := 0 + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} + return test{ o: o, - res: &clone, csr: csr, - sa: &mockSignAuth{ - sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) { - assert.Equals(t, len(signOps), 6) - return []*x509.Certificate{crt, inter}, nil + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil + }, + MgetOptions: func() *provisioner.Options { + return nil }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - clone.Certificate = string(key) - } - count++ - return nil, true, nil + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil }, }, - } - }, - "ok/ready/sans-and-name": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return errors.New("force") + }, }, - DNSNames: []string{"step.example.com"}, + err: NewErrorISE("error creating certificate for order oID: force"), } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", + }, + "fail/error-db.UpdateOrder": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, - DNSNames: []string{"acme.example.com", "step.example.com"}, } - inter := &x509.Certificate{ + csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "intermediate", + CommonName: "foo.internal", }, + DNSNames: []string{"bar.internal"}, } - clone := *o - clone.Status = StatusValid - count := 0 + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} + return test{ o: o, - res: &clone, csr: csr, - sa: &mockSignAuth{ - sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) { - assert.Equals(t, len(signOps), 6) - return []*x509.Certificate{crt, inter}, nil - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - clone.Certificate = string(key) - } - count++ - return nil, true, nil + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - p := tc.prov - if p == nil { - p = prov - } - o, err := tc.o.finalize(tc.db, tc.csr, tc.sa, p) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } - } - }) - } -} - -func Test_getOrderIDsByAccount(t *testing.T) { - type test struct { - id string - db nosql.DB - res []string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/not-found": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound + MgetOptions: func() *provisioner.Options { + return nil }, }, - res: []string{}, - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil }, }, - err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, nil + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + cert.ID = "certID" + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return nil }, - }, - err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), - } - }, - "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.CertificateID, "certID") + assert.Equals(t, updo.Status, StatusValid) + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, updo.Identifiers, o.Identifiers) + return errors.New("force") }, }, - err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), + err: NewErrorISE("error updating order oID: force"), } }, - "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return bo, nil - case 3: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(o.Authorizations[0])) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } - }, + "ok/new-cert": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, - err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), } - }, - "ok/no-change-to-pending-orders": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("should not be attempting to store anything") - }, + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", }, - res: oids, + DNSNames: []string{"bar.internal"}, } - }, - "fail/error-storing-new-oids": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, false, errors.New("force") + MgetOptions: func() *provisioner.Options { + return nil }, }, - err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), - } - }, - "ok": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3", "o4"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 || dbGetOrder == 3 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, true, nil + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil }, }, - res: []string{"o2", "o4"}, - } - }, - "ok/no-pending-orders": func(t *testing.T) test { - oids := []string{"o1"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return binvalidOrder, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + cert.ID = "certID" + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return nil }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - assert.Equals(t, old, boids) - assert.Nil(t, newval) - return nil, true, nil + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.CertificateID, "certID") + assert.Equals(t, updo.Status, StatusValid) + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, updo.Identifiers, o.Identifiers) + return nil }, }, - res: []string{}, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - var oiba = orderIDsByAccount{} - if oids, err := oiba.unsafeGetOrderIDsByAccount(tc.db, tc.id); err != nil { + if err := tc.o.Finalize(context.Background(), tc.db, tc.csr, tc.ca, tc.prov); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res, oids) - } + assert.Nil(t, tc.err) } }) } } -*/ From df05340521d0e5fed8726cf7e386b45ab9e17447 Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 24 Mar 2021 20:46:02 -0700 Subject: [PATCH 34/47] fixing broken unit tests --- acme/api/handler_test.go | 8 ++++---- acme/api/order_test.go | 2 ++ acme/db/nosql/order_test.go | 2 +- ca/acmeClient_test.go | 11 ++++++----- ca/ca.go | 12 +++++++++--- 5 files changed, 22 insertions(+), 13 deletions(-) diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 7fd8e110..19e5da76 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -478,8 +478,8 @@ func TestHandler_GetChallenge(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, - statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), + statusCode: 500, + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { @@ -489,8 +489,8 @@ func TestHandler_GetChallenge(t *testing.T) { ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, - statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), + statusCode: 500, + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/db.GetChallenge-error": func(t *testing.T) test { diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 0bc3caab..84136fa3 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -374,6 +374,7 @@ func TestHandler_GetOrder(t *testing.T) { } } +/* func TestHandler_NewOrder(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) nbf := time.Now().UTC().Add(5 * time.Hour) @@ -588,6 +589,7 @@ func TestHandler_NewOrder(t *testing.T) { }) } } +*/ func TestHandler_FinalizeOrder(t *testing.T) { now := clock.Now() diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go index 3636837c..746066a2 100644 --- a/acme/db/nosql/order_test.go +++ b/acme/db/nosql/order_test.go @@ -665,7 +665,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) { return nil, false, errors.New("force") }, }, - acmeErr: acme.NewErrorISE("error updating order foo for account accID: error saving acme order: force"), + acmeErr: acme.NewErrorISE("error updating order foo for account accID: error updating order: error saving acme order: force"), } }, "fail/db.save-order-error": func(t *testing.T) test { diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index 3fbd42c5..b97fdbd0 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -376,19 +376,20 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + now := time.Now().UTC().Round(time.Second) nor := acmeAPI.NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "dns", Value: "acme.example.com"}, }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute), + NotBefore: now, + NotAfter: now.Add(time.Minute), } norb, err := json.Marshal(nor) assert.FatalError(t, err) ord := acme.Order{ Status: "valid", - ExpiresAt: time.Now(), // "soon" + ExpiresAt: now, // "soon" FinalizeURL: "finalize-url", } ac := &ACMEClient{ @@ -510,7 +511,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) ord := acme.Order{ Status: "valid", - ExpiresAt: time.Now(), // "soon" + ExpiresAt: time.Now().UTC().Round(time.Second), // "soon" FinalizeURL: "finalize-url", } ac := &ACMEClient{ @@ -630,7 +631,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) az := acme.Authorization{ Status: "valid", - ExpiresAt: time.Now(), + ExpiresAt: time.Now().UTC().Round(time.Second), Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, } ac := &ACMEClient{ diff --git a/ca/ca.go b/ca/ca.go index 43cbf0ba..e8eb74f8 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -11,6 +11,7 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" @@ -124,9 +125,14 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } prefix := "acme" - acmeDB, err := acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) - if err != nil { - return nil, errors.Wrap(err, "error configuring ACME DB interface") + var acmeDB acme.DB + if config.DB == nil { + acmeDB = nil + } else { + acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) + if err != nil { + return nil, errors.Wrap(err, "error configuring ACME DB interface") + } } acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ Backdate: *config.AuthorityConfig.Backdate, From b6ebc0fd25972beb512aacf3760a24b66c896fdd Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 00:23:57 -0700 Subject: [PATCH 35/47] more unit tests --- acme/api/account_test.go | 2 +- acme/api/handler.go | 28 ++- acme/api/handler_test.go | 37 ++- acme/api/linker.go | 8 +- acme/api/linker_test.go | 13 +- acme/api/order.go | 25 +- acme/api/order_test.go | 426 +++++++++++++++++++++++++++----- acme/challenge.go | 1 - acme/challenge_test.go | 262 +++++++------------- acme/db/nosql/challenge.go | 3 - acme/db/nosql/challenge_test.go | 13 - 11 files changed, 523 insertions(+), 295 deletions(-) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 918c31c5..7cbe7b7c 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -28,7 +28,7 @@ var ( } ) -func newProv() provisioner.Interface { +func newProv() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ Type: "ACME", diff --git a/acme/api/handler.go b/acme/api/handler.go index 47c93dfc..31477fca 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -70,14 +71,28 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } + resolver := &net.Resolver{ + // The DNS resolver can be configured for testing purposes with something + // like this: + // + // PreferGo: true, + // Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + // var d net.Dialer + // return d.DialContext(ctx, "udp", "127.0.0.1:5333") + // }, + } return &Handler{ ca: ops.CA, db: ops.DB, backdate: ops.Backdate, linker: NewLinker(ops.DNS, ops.Prefix), validateChallengeOptions: &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, + HTTPGet: client.Get, + LookupTxt: func(name string) ([]string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + return resolver.LookupTXT(ctx, name) + }, TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(dialer, network, addr, config) }, @@ -216,7 +231,8 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // strict enforcement would render these clients broken. For the time being // we'll just ignore the body. - ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), chi.URLParam(r, "authzID")) + azID := chi.URLParam(r, "authzID") + ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) return @@ -236,10 +252,10 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { return } - h.linker.LinkChallenge(ctx, ch) + h.linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, ch.AuthzID), "up")) - w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID)) + w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, azID), "up")) + w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID)) api.JSON(w, ch) } diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 19e5da76..eb8b9f56 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -126,20 +126,18 @@ func TestHandler_GetAuthorization(t *testing.T) { Wildcard: false, Challenges: []*acme.Challenge{ { - Type: "http-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chHTTPID", - ID: "chHTTP01ID", - AuthzID: "authzID", + Type: "http-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chHTTPID", + ID: "chHTTP01ID", }, { - Type: "dns-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chDNSID", - ID: "chDNSID", - AuthzID: "authzID", + Type: "dns-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chDNSID", + ID: "chDNSID", }, }, } @@ -429,12 +427,11 @@ func TestHandler_GetCertificate(t *testing.T) { func ch() acme.Challenge { return acme.Challenge{ - Type: "http-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chID", - ID: "chID", - AuthzID: "authzID", + Type: "http-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chID", + ID: "chID", } } @@ -627,7 +624,6 @@ func TestHandler_GetChallenge(t *testing.T) { assert.Equals(t, azID, "authzID") return &acme.Challenge{ ID: "chID", - AuthzID: "authzID", Status: acme.StatusPending, Type: "http-01", AccountID: "accID", @@ -643,7 +639,6 @@ func TestHandler_GetChallenge(t *testing.T) { }, ch: &acme.Challenge{ ID: "chID", - AuthzID: "authzID", Status: acme.StatusPending, Type: "http-01", AccountID: "accID", @@ -689,7 +684,7 @@ func TestHandler_GetChallenge(t *testing.T) { expB, err := json.Marshal(tc.ch) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, tc.ch.AuthzID)}) + assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")}) assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } diff --git a/acme/api/linker.go b/acme/api/linker.go index b5995852..9459e1bc 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -20,7 +20,7 @@ type Linker interface { LinkOrder(ctx context.Context, o *acme.Order) LinkAccount(ctx context.Context, o *acme.Account) - LinkChallenge(ctx context.Context, o *acme.Challenge) + LinkChallenge(ctx context.Context, o *acme.Challenge, azID string) LinkAuthorization(ctx context.Context, o *acme.Authorization) LinkOrdersByAccountID(ctx context.Context, orders []string) } @@ -164,14 +164,14 @@ func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { } // LinkChallenge sets the ACME links required by an ACME challenge. -func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { - ch.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) +func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { + ch.URL = l.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { for _, ch := range az.Challenges { - l.LinkChallenge(ctx, ch) + l.LinkChallenge(ctx, ch, az.ID) } } diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go index c3075a1a..2252e334 100644 --- a/acme/api/linker_test.go +++ b/acme/api/linker_test.go @@ -214,17 +214,16 @@ func TestLinker_LinkChallenge(t *testing.T) { var tests = map[string]test{ "ok": { ch: &acme.Challenge{ - ID: chID, - AuthzID: azID, + ID: chID, }, validate: func(ch *acme.Challenge) { - assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, ch.AuthzID, ch.ID)) + assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) }, }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { - l.LinkChallenge(ctx, tc.ch) + l.LinkChallenge(ctx, tc.ch, azID) tc.validate(tc.ch) }) } @@ -252,9 +251,9 @@ func TestLinker_LinkAuthorization(t *testing.T) { az: &acme.Authorization{ ID: azID, Challenges: []*acme.Challenge{ - {ID: chID0, AuthzID: azID}, - {ID: chID1, AuthzID: azID}, - {ID: chID2, AuthzID: azID}, + {ID: chID0}, + {ID: chID1}, + {ID: chID2}, }, }, validate: func(az *acme.Authorization) { diff --git a/acme/api/order.go b/acme/api/order.go index 379c2287..9f557d7f 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -59,6 +59,7 @@ func (f *FinalizeRequest) Validate() error { } var defaultOrderExpiry = time.Hour * 24 +var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { @@ -90,22 +91,23 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } now := clock.Now() - expiry := now.Add(defaultOrderExpiry) // New order. o := &acme.Order{ - AccountID: acc.ID, - ProvisionerID: prov.GetID(), - Status: acme.StatusPending, - ExpiresAt: expiry, - Identifiers: nor.Identifiers, + AccountID: acc.ID, + ProvisionerID: prov.GetID(), + Status: acme.StatusPending, + Identifiers: nor.Identifiers, + ExpiresAt: now.Add(defaultOrderExpiry), + AuthorizationIDs: make([]string, len(nor.Identifiers)), + NotBefore: nor.NotBefore, + NotAfter: nor.NotAfter, } - o.AuthorizationIDs = make([]string, len(o.Identifiers)) for i, identifier := range o.Identifiers { az := &acme.Authorization{ AccountID: acc.ID, Identifier: identifier, - ExpiresAt: expiry, + ExpiresAt: o.ExpiresAt, Status: acme.StatusPending, } if err := h.newAuthorization(ctx, az); err != nil { @@ -121,6 +123,9 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { if o.NotAfter.IsZero() { o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration()) } + if nor.NotBefore.IsZero() { + o.NotBefore.Add(-defaultOrderBackdate) + } if err := h.db.CreateOrder(ctx, o); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) @@ -155,19 +160,17 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) if err != nil { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } - az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ AccountID: az.AccountID, - AuthzID: az.ID, Value: az.Identifier.Value, Type: typ, Token: az.Token, Status: acme.StatusPending, } if err := h.db.CreateChallenge(ctx, ch); err != nil { - return err + return acme.WrapErrorISE(err, "error creating challenge") } az.Challenges[i] = ch } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 84136fa3..62652812 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/go-chi/chi" + "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "go.step.sm/crypto/pemutil" @@ -374,33 +375,221 @@ func TestHandler_GetOrder(t *testing.T) { } } -/* -func TestHandler_NewOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC().Add(5 * time.Hour) - naf := nbf.Add(17 * time.Hour) - o := acme.Order{ - ID: "orderID", - ExpiresAt: expiry, - NotBefore: nbf, - NotAfter: naf, - Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, +func TestHandler_newAuthorization(t *testing.T) { + type test struct { + az *acme.Authorization + db acme.DB + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/error-db.CreateChallenge": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }, + } + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return errors.New("force") + }, + }, + az: az, + err: acme.NewErrorISE("error creating challenge: force"), + } + }, + "fail/error-db.CreateAuthorization": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + count := 0 + var ch1, ch2, ch3 **acme.Challenge + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, az.Identifier) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, _az.Wildcard, false) + return errors.New("force") + }, + }, + az: az, + err: acme.NewErrorISE("error creating authorization: force"), + } + }, + "ok/no-wildcard": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + count := 0 + var ch1, ch2, ch3 **acme.Challenge + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, az.Identifier) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, _az.Wildcard, false) + return nil + }, + }, + az: az, + } + }, + "ok/wildcard": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "*.zap.internal", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + var ch1 **acme.Challenge + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + ch1 = &ch + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1}) + assert.Equals(t, _az.Wildcard, true) + return nil + }, + }, + az: az, + } }, - Status: "pending", - AuthorizationURLs: []string{"foo", "bar"}, } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + h := &Handler{db: tc.db} + if err := h.newAuthorization(context.Background(), tc.az); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *acme.Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + + } +} +func TestHandler_NewOrder(t *testing.T) { + // Request with chi context prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/new-order", + url := fmt.Sprintf("%s/acme/%s/order/ordID", baseURL.String(), provName) type test struct { db acme.DB ctx context.Context + nor *NewOrderRequest statusCode int err *acme.Error } @@ -422,33 +611,43 @@ func TestHandler_NewOrder(t *testing.T) { } }, "fail/no-provisioner": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("provisioner expected in request context"), + err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/nil-provisioner": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("provisioner expected in request context"), + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/no-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("paylod does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { @@ -464,8 +663,8 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/malformed-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{} - b, err := json.Marshal(nor) + fr := &NewOrderRequest{} + b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -476,86 +675,179 @@ func TestHandler_NewOrder(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, - "fail/NewOrder-error": func(t *testing.T) test { + "fail/error-h.newAuthorization": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{ + fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "zap.internal"}, }, } - b, err := json.Marshal(nor) + b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + ctx: ctx, + statusCode: 500, db: &acme.MockDB{ - MockCreateOrder: func(ctx context.Context, o *acme.Order) error { - return acme.NewError(acme.ErrorMalformedType, "force") + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.AccountID, "accID") + assert.Equals(t, ch.Type, "dns-01") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return errors.New("force") }, }, - ctx: ctx, - statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "force"), + err: acme.NewErrorISE("error creating challenge: force"), } }, - "ok": func(t *testing.T) test { + "fail/error-db.CreateOrder": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{ + fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "zap.internal"}, }, - NotBefore: nbf, - NotAfter: naf, } - b, err := json.Marshal(nor) + b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) return test{ + ctx: ctx, + statusCode: 500, db: &acme.MockDB{ - MockCreateOrder: func(ctx context.Context, o *acme.Order) error { - o.ID = "orderID" + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") return nil }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, fr.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, fr.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return errors.New("force") + }, }, - ctx: ctx, - statusCode: 201, + err: acme.NewErrorISE("error creating order: force"), } }, - "ok/default-naf-nbf": func(t *testing.T) test { + "ok/no-naf-nbf": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{ + fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "zap.internal"}, }, } - b, err := json.Marshal(nor) + b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) return test{ + ctx: ctx, + statusCode: 201, + nor: fr, db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, fr.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, fr.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, }, - ctx: ctx, - statusCode: 201, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "prefix"), db: tc.db} + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -578,18 +870,30 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(o) - assert.FatalError(t, err) - assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), - provName, o.ID)}) + ro := new(acme.Order) + err = json.Unmarshal(body, ro) + + now := clock.Now() + orderExpiry := now.Add(defaultOrderExpiry) + certExpiry := now.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, ro.ID, "ordID") + assert.Equals(t, ro.Status, acme.StatusPending) + assert.Equals(t, ro.Identifiers, tc.nor.Identifiers) + assert.Equals(t, ro.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.True(t, ro.NotBefore.Add(-time.Minute).Before(now)) + assert.True(t, ro.NotBefore.Add(time.Minute).After(now)) + assert.True(t, ro.NotAfter.Add(-time.Minute).Before(certExpiry)) + assert.True(t, ro.NotAfter.Add(time.Minute).After(certExpiry)) + assert.True(t, ro.ExpiresAt.Add(-time.Minute).Before(orderExpiry)) + assert.True(t, ro.ExpiresAt.Add(time.Minute).After(orderExpiry)) + + assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } -*/ func TestHandler_FinalizeOrder(t *testing.T) { now := clock.Now() diff --git a/acme/challenge.go b/acme/challenge.go index a3514d15..b4f151cd 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -29,7 +29,6 @@ type Challenge struct { URL string `json:"url"` Error *Error `json:"error,omitempty"` ID string `json:"-"` - AuthzID string `json:"-"` AccountID string `json:"-"` Value string `json:"-"` } diff --git a/acme/challenge_test.go b/acme/challenge_test.go index d5b6cc58..caaca8f6 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -39,17 +39,15 @@ func Test_storeError(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/db.UpdateChallenge-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -66,17 +64,15 @@ func Test_storeError(t *testing.T) { }, "fail/db.UpdateChallenge-acme-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -93,17 +89,15 @@ func Test_storeError(t *testing.T) { }, "ok": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -236,12 +230,11 @@ func TestChallenge_Validate(t *testing.T) { }, "fail/http-01": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Status: StatusPending, - AuthzID: "azID", - Type: "http-01", - Token: "token", - Value: "zap.internal", + ID: "chID", + Status: StatusPending, + Type: "http-01", + Token: "token", + Value: "zap.internal", } return test{ @@ -254,7 +247,6 @@ func TestChallenge_Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Status, ch.Status) @@ -274,12 +266,11 @@ func TestChallenge_Validate(t *testing.T) { }, "ok/http-01": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Status: StatusPending, - AuthzID: "azID", - Type: "http-01", - Token: "token", - Value: "zap.internal", + ID: "chID", + Status: StatusPending, + Type: "http-01", + Token: "token", + Value: "zap.internal", } return test{ @@ -292,7 +283,6 @@ func TestChallenge_Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Status, ch.Status) @@ -311,12 +301,11 @@ func TestChallenge_Validate(t *testing.T) { }, "fail/dns-01": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Type: "dns-01", - Status: StatusPending, - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Type: "dns-01", + Status: StatusPending, + Token: "token", + Value: "zap.internal", } return test{ @@ -329,7 +318,6 @@ func TestChallenge_Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Status, ch.Status) @@ -350,12 +338,11 @@ func TestChallenge_Validate(t *testing.T) { }, "ok/dns-01": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Type: "dns-01", - Status: StatusPending, - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Type: "dns-01", + Status: StatusPending, + Token: "token", + Value: "zap.internal", } return test{ @@ -368,7 +355,6 @@ func TestChallenge_Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Status, ch.Status) @@ -388,12 +374,11 @@ func TestChallenge_Validate(t *testing.T) { }, "fail/tls-alpn-01": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Type: "tls-alpn-01", - Status: StatusPending, - Value: "zap.internal", + ID: "chID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", } return test{ ch: ch, @@ -405,7 +390,6 @@ func TestChallenge_Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -426,12 +410,11 @@ func TestChallenge_Validate(t *testing.T) { }, "ok/tls-alpn-01": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Type: "tls-alpn-01", - Status: StatusPending, - Value: "zap.internal", + ID: "chID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -455,7 +438,6 @@ func TestChallenge_Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -517,10 +499,9 @@ func TestHTTP01Validate(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/http-get-error-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ @@ -533,7 +514,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) @@ -550,10 +530,9 @@ func TestHTTP01Validate(t *testing.T) { }, "ok/http-get-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ @@ -566,7 +545,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) @@ -582,10 +560,9 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/http-get->=400-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ @@ -600,7 +577,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) @@ -617,10 +593,9 @@ func TestHTTP01Validate(t *testing.T) { }, "ok/http-get->=400": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ @@ -635,7 +610,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) @@ -651,10 +625,9 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/read-body": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } return test{ @@ -671,10 +644,9 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -695,10 +667,9 @@ func TestHTTP01Validate(t *testing.T) { }, "ok/key-auth-mismatch": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -719,7 +690,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -737,10 +707,9 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -761,7 +730,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -780,10 +748,9 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/update-challenge-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -804,7 +771,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) assert.Equals(t, updch.Status, StatusValid) @@ -823,10 +789,9 @@ func TestHTTP01Validate(t *testing.T) { }, "ok": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -847,7 +812,6 @@ func TestHTTP01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -900,10 +864,9 @@ func TestDNS01Validate(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/lookupTXT-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } return test{ @@ -916,7 +879,6 @@ func TestDNS01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) @@ -934,10 +896,9 @@ func TestDNS01Validate(t *testing.T) { }, "ok/lookupTXT-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } return test{ @@ -950,7 +911,6 @@ func TestDNS01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) @@ -967,10 +927,9 @@ func TestDNS01Validate(t *testing.T) { }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -990,10 +949,9 @@ func TestDNS01Validate(t *testing.T) { }, "fail/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1012,7 +970,6 @@ func TestDNS01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) @@ -1031,10 +988,9 @@ func TestDNS01Validate(t *testing.T) { }, "ok/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1053,7 +1009,6 @@ func TestDNS01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) @@ -1071,10 +1026,9 @@ func TestDNS01Validate(t *testing.T) { }, "fail/update-challenge-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1095,7 +1049,6 @@ func TestDNS01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -1116,10 +1069,9 @@ func TestDNS01Validate(t *testing.T) { }, "ok": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1140,7 +1092,6 @@ func TestDNS01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) @@ -1277,12 +1228,11 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na func TestTLSALPN01Validate(t *testing.T) { makeTLSCh := func() *Challenge { return &Challenge{ - ID: "chID", - AuthzID: "azID", - Token: "token", - Type: "tls-alpn-01", - Status: StatusPending, - Value: "zap.internal", + ID: "chID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", } } type test struct { @@ -1306,7 +1256,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1337,7 +1286,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1369,7 +1317,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1401,7 +1348,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1432,7 +1378,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1469,7 +1414,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1507,7 +1451,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1552,7 +1495,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1596,7 +1538,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1641,7 +1582,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1685,7 +1625,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1752,7 +1691,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1792,7 +1730,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1837,7 +1774,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1881,7 +1817,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1922,7 +1857,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -1962,7 +1896,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -2008,7 +1941,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -2055,7 +1987,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -2102,7 +2033,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -2147,7 +2077,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) @@ -2193,7 +2122,6 @@ func TestTLSALPN01Validate(t *testing.T) { db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) - assert.Equals(t, updch.AuthzID, ch.AuthzID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Status, ch.Status) assert.Equals(t, updch.Type, ch.Type) diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index afcb4600..f3a3cfca 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -13,7 +13,6 @@ import ( type dbChallenge struct { ID string `json:"id"` AccountID string `json:"accountID"` - AuthzID string `json:"authzID"` Type string `json:"type"` Status acme.Status `json:"status"` Token string `json:"token"` @@ -54,7 +53,6 @@ func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error { dbch := &dbChallenge{ ID: ch.ID, - AuthzID: ch.AuthzID, AccountID: ch.AccountID, Value: ch.Value, Status: acme.StatusPending, @@ -77,7 +75,6 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall ch := &acme.Challenge{ ID: dbch.ID, AccountID: dbch.AccountID, - AuthzID: dbch.AuthzID, Type: dbch.Type, Value: dbch.Value, Status: dbch.Status, diff --git a/acme/db/nosql/challenge_test.go b/acme/db/nosql/challenge_test.go index 314fc5f7..b39395e8 100644 --- a/acme/db/nosql/challenge_test.go +++ b/acme/db/nosql/challenge_test.go @@ -66,7 +66,6 @@ func TestDB_getDBChallenge(t *testing.T) { dbc := &dbChallenge{ ID: chID, AccountID: "accountID", - AuthzID: "authzID", Type: "dns-01", Status: acme.StatusPending, Token: "token", @@ -113,7 +112,6 @@ func TestDB_getDBChallenge(t *testing.T) { if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID) - assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID) assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token) @@ -137,7 +135,6 @@ func TestDB_CreateChallenge(t *testing.T) { "fail/cmpAndSwap-error": func(t *testing.T) test { ch := &acme.Challenge{ AccountID: "accountID", - AuthzID: "authzID", Type: "dns-01", Status: acme.StatusPending, Token: "token", @@ -154,7 +151,6 @@ func TestDB_CreateChallenge(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.AccountID, ch.AccountID) - assert.Equals(t, dbc.AuthzID, ch.AuthzID) assert.Equals(t, dbc.Type, ch.Type) assert.Equals(t, dbc.Status, ch.Status) assert.Equals(t, dbc.Token, ch.Token) @@ -174,7 +170,6 @@ func TestDB_CreateChallenge(t *testing.T) { idPtr = &id ch = &acme.Challenge{ AccountID: "accountID", - AuthzID: "authzID", Type: "dns-01", Status: acme.StatusPending, Token: "token", @@ -195,7 +190,6 @@ func TestDB_CreateChallenge(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.AccountID, ch.AccountID) - assert.Equals(t, dbc.AuthzID, ch.AuthzID) assert.Equals(t, dbc.Type, ch.Type) assert.Equals(t, dbc.Status, ch.Status) assert.Equals(t, dbc.Token, ch.Token) @@ -266,7 +260,6 @@ func TestDB_GetChallenge(t *testing.T) { dbc := &dbChallenge{ ID: chID, AccountID: "accountID", - AuthzID: azID, Type: "dns-01", Status: acme.StatusPending, Token: "token", @@ -313,7 +306,6 @@ func TestDB_GetChallenge(t *testing.T) { if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID) - assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID) assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token) @@ -331,7 +323,6 @@ func TestDB_UpdateChallenge(t *testing.T) { dbc := &dbChallenge{ ID: chID, AccountID: "accountID", - AuthzID: "azID", Type: "dns-01", Status: acme.StatusPending, Token: "token", @@ -390,7 +381,6 @@ func TestDB_UpdateChallenge(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbc.ID) assert.Equals(t, dbNew.AccountID, dbc.AccountID) - assert.Equals(t, dbNew.AuthzID, dbc.AuthzID) assert.Equals(t, dbNew.Type, dbc.Type) assert.Equals(t, dbNew.Status, updCh.Status) assert.Equals(t, dbNew.Token, dbc.Token) @@ -408,7 +398,6 @@ func TestDB_UpdateChallenge(t *testing.T) { updCh := &acme.Challenge{ ID: dbc.ID, AccountID: dbc.AccountID, - AuthzID: dbc.AuthzID, Type: dbc.Type, Token: dbc.Token, Value: dbc.Value, @@ -437,7 +426,6 @@ func TestDB_UpdateChallenge(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbc.ID) assert.Equals(t, dbNew.AccountID, dbc.AccountID) - assert.Equals(t, dbNew.AuthzID, dbc.AuthzID) assert.Equals(t, dbNew.Type, dbc.Type) assert.Equals(t, dbNew.Token, dbc.Token) assert.Equals(t, dbNew.Value, dbc.Value) @@ -463,7 +451,6 @@ func TestDB_UpdateChallenge(t *testing.T) { if assert.Nil(t, tc.err) { assert.Equals(t, tc.ch.ID, dbc.ID) assert.Equals(t, tc.ch.AccountID, dbc.AccountID) - assert.Equals(t, tc.ch.AuthzID, dbc.AuthzID) assert.Equals(t, tc.ch.Type, dbc.Type) assert.Equals(t, tc.ch.Token, dbc.Token) assert.Equals(t, tc.ch.Value, dbc.Value) From 18319203631c4ae2367dd03aa7cd37977f537071 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 13:46:51 -0700 Subject: [PATCH 36/47] Finish order unit tests and remove unused mocklinker --- acme/api/handler.go | 19 +- acme/api/linker.go | 78 -------- acme/api/order.go | 4 +- acme/api/order_test.go | 439 ++++++++++++++++++++++++++++++++++++++--- acme/order.go | 30 ++- 5 files changed, 436 insertions(+), 134 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index 31477fca..16deeaf8 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,7 +1,6 @@ package api import ( - "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -71,28 +70,14 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } - resolver := &net.Resolver{ - // The DNS resolver can be configured for testing purposes with something - // like this: - // - // PreferGo: true, - // Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - // var d net.Dialer - // return d.DialContext(ctx, "udp", "127.0.0.1:5333") - // }, - } return &Handler{ ca: ops.CA, db: ops.DB, backdate: ops.Backdate, linker: NewLinker(ops.DNS, ops.Prefix), validateChallengeOptions: &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: func(name string) ([]string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - return resolver.LookupTXT(ctx, name) - }, + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(dialer, network, addr, config) }, diff --git a/acme/api/linker.go b/acme/api/linker.go index 9459e1bc..b6a44dfa 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -181,81 +181,3 @@ func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { orders[i] = l.GetLink(ctx, OrderLinkType, true, id) } } - -// MockLinker implements the Linker interface. Only used for testing. -type MockLinker struct { - MockGetLink func(ctx context.Context, typ LinkType, abs bool, inputs ...string) string - MockGetLinkExplicit func(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string - - MockLinkOrder func(ctx context.Context, o *acme.Order) - MockLinkAccount func(ctx context.Context, o *acme.Account) - MockLinkChallenge func(ctx context.Context, o *acme.Challenge) - MockLinkAuthorization func(ctx context.Context, o *acme.Authorization) - MockLinkOrdersByAccountID func(ctx context.Context, orders []string) - - MockError error - MockRet1 interface{} -} - -// GetLink mock. -func (m *MockLinker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { - if m.MockGetLink != nil { - return m.MockGetLink(ctx, typ, abs, inputs...) - } - - return m.MockRet1.(string) -} - -// GetLinkExplicit mock. -func (m *MockLinker) GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string { - if m.MockGetLinkExplicit != nil { - return m.MockGetLinkExplicit(typ, provName, abs, baseURL, inputs...) - } - - return m.MockRet1.(string) -} - -// LinkOrder mock. -func (m *MockLinker) LinkOrder(ctx context.Context, o *acme.Order) { - if m.MockLinkOrder != nil { - m.MockLinkOrder(ctx, o) - return - } - return -} - -// LinkAccount mock. -func (m *MockLinker) LinkAccount(ctx context.Context, o *acme.Account) { - if m.MockLinkAccount != nil { - m.MockLinkAccount(ctx, o) - return - } - return -} - -// LinkChallenge mock. -func (m *MockLinker) LinkChallenge(ctx context.Context, o *acme.Challenge) { - if m.MockLinkChallenge != nil { - m.MockLinkChallenge(ctx, o) - return - } - return -} - -// LinkAuthorization mock. -func (m *MockLinker) LinkAuthorization(ctx context.Context, o *acme.Authorization) { - if m.MockLinkAuthorization != nil { - m.MockLinkAuthorization(ctx, o) - return - } - return -} - -// LinkOrderAccountsByID mock. -func (m *MockLinker) LinkOrderAccountsByID(ctx context.Context, orders []string) { - if m.MockLinkOrdersByAccountID != nil { - m.MockLinkOrdersByAccountID(ctx, orders) - return - } - return -} diff --git a/acme/api/order.go b/acme/api/order.go index 9f557d7f..e7a913ab 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -123,8 +123,10 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { if o.NotAfter.IsZero() { o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration()) } + // If request NotBefore was empty then backdate the order.NotBefore (now) + // to avoid timing issues. if nor.NotBefore.IsZero() { - o.NotBefore.Add(-defaultOrderBackdate) + o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } if err := h.db.CreateOrder(ctx, o); err != nil { diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 62652812..597ec018 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -591,6 +591,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx context.Context nor *NewOrderRequest statusCode int + vr func(t *testing.T, o *acme.Order) err *acme.Error } var tests = map[string]func(t *testing.T) test{ @@ -772,14 +773,130 @@ func TestHandler_NewOrder(t *testing.T) { err: acme.NewErrorISE("error creating order: force"), } }, - "ok/no-naf-nbf": func(t *testing.T) test { + "ok/multiple-authz": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - fr := &NewOrderRequest{ + nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, + {Type: "dns", Value: "*.zar.internal"}, }, } - b, err := json.Marshal(fr) + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3, ch4 **acme.Challenge + az1ID, az2ID *string + chCount, azCount = 0, 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch chCount { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Value, "zap.internal") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Value, "zap.internal") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Value, "zap.internal") + ch3 = &ch + case 3: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Value, "zar.internal") + ch4 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + chCount++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + switch azCount { + case 0: + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Wildcard, false) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + case 1: + az.ID = "az2ID" + az2ID = &az.ID + assert.Equals(t, az.Identifier, acme.Identifier{ + Type: "dns", + Value: "zar.internal", + }) + assert.Equals(t, az.Wildcard, true) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch4}) + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + azCount++ + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + expNaf := now.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az2ID", + }) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/default-naf-nbf": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -793,7 +910,7 @@ func TestHandler_NewOrder(t *testing.T) { return test{ ctx: ctx, statusCode: 201, - nor: fr, + nor: nor, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -826,7 +943,7 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) - assert.Equals(t, az.Identifier, fr.Identifiers[0]) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil @@ -836,11 +953,301 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) - assert.Equals(t, o.Identifiers, fr.Identifiers) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + expNaf := now.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/nbf-no-naf": func(t *testing.T) test { + now := clock.Now() + expNbf := now.Add(10 * time.Minute) + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + NotBefore: expNbf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNaf := expNbf.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/naf-no-nbf": func(t *testing.T) test { + now := clock.Now() + expNaf := now.Add(15 * time.Minute) + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + NotAfter: expNaf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/naf-nbf": func(t *testing.T) test { + now := clock.Now() + expNbf := now.Add(5 * time.Minute) + expNaf := now.Add(15 * time.Minute) + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + NotBefore: expNbf, + NotAfter: expNaf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, } }, } @@ -871,22 +1278,10 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { ro := new(acme.Order) - err = json.Unmarshal(body, ro) - - now := clock.Now() - orderExpiry := now.Add(defaultOrderExpiry) - certExpiry := now.Add(prov.DefaultTLSCertDuration()) - - assert.Equals(t, ro.ID, "ordID") - assert.Equals(t, ro.Status, acme.StatusPending) - assert.Equals(t, ro.Identifiers, tc.nor.Identifiers) - assert.Equals(t, ro.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) - assert.True(t, ro.NotBefore.Add(-time.Minute).Before(now)) - assert.True(t, ro.NotBefore.Add(time.Minute).After(now)) - assert.True(t, ro.NotAfter.Add(-time.Minute).Before(certExpiry)) - assert.True(t, ro.NotAfter.Add(time.Minute).After(certExpiry)) - assert.True(t, ro.ExpiresAt.Add(-time.Minute).Before(orderExpiry)) - assert.True(t, ro.ExpiresAt.Add(time.Minute).After(orderExpiry)) + assert.FatalError(t, json.Unmarshal(body, ro)) + if tc.vr != nil { + tc.vr(t, ro) + } assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) diff --git a/acme/order.go b/acme/order.go index fdc250ea..a6112362 100644 --- a/acme/order.go +++ b/acme/order.go @@ -20,22 +20,20 @@ type Identifier struct { // Order contains order metadata for the ACME protocol order type. type Order struct { - ID string `json:"id"` - Status Status `json:"status"` - ExpiresAt time.Time `json:"expires,omitempty"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore,omitempty"` - NotAfter time.Time `json:"notAfter,omitempty"` - Error *Error `json:"error,omitempty"` - AuthorizationIDs []string `json:"-"` - AuthorizationURLs []string `json:"authorizations"` - FinalizeURL string `json:"finalize"` - CertificateID string `json:"-"` - CertificateURL string `json:"certificate,omitempty"` - AccountID string `json:"-"` - ProvisionerID string `json:"-"` - DefaultDuration time.Duration `json:"-"` - Backdate time.Duration `json:"-"` + ID string `json:"id"` + Status Status `json:"status"` + ExpiresAt time.Time `json:"expires,omitempty"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + Error *Error `json:"error,omitempty"` + AuthorizationIDs []string `json:"-"` + AuthorizationURLs []string `json:"authorizations"` + FinalizeURL string `json:"finalize"` + CertificateID string `json:"-"` + CertificateURL string `json:"certificate,omitempty"` + AccountID string `json:"-"` + ProvisionerID string `json:"-"` } // ToLog enables response logging. From 80c8567d9977e2dadb8c035bce0af29ea7aee1de Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 14:54:12 -0700 Subject: [PATCH 37/47] change errnotfound type for getAccount - more generalized NotFound type rather than the nosql one we were using - if the error is not recognized then the logic in create account will break. --- acme/api/middleware.go | 15 ++++++++--- acme/api/middleware_test.go | 2 +- acme/db.go | 8 ++++++ acme/db/nosql/account.go | 4 +-- acme/db/nosql/account_test.go | 50 ++--------------------------------- 5 files changed, 24 insertions(+), 55 deletions(-) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index f2a35c3a..e06e4736 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -3,6 +3,7 @@ package api import ( "context" "crypto/rsa" + "errors" "io/ioutil" "net/http" "net/url" @@ -243,15 +244,21 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } - ctx = context.WithValue(ctx, jwkContextKey, jwk) - kid, err := acme.KeyToID(jwk) + + // Overwrite KeyID with the JWK thumbprint. + jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } - acc, err := h.db.GetAccountByKeyID(ctx, kid) + + // Store the JWK in the context. + ctx = context.WithValue(ctx, jwkContextKey, jwk) + + // Get Account or continue to generate a new one. + acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) switch { - case nosql.IsErrNotFound(err): + case errors.Is(err, acme.ErrNotFound): // For NewAccount requests ... break case err != nil: diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 4f2c4bcb..1c0f3689 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1047,7 +1047,7 @@ func TestHandler_extractJWK(t *testing.T) { db: &acme.MockDB{ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { assert.Equals(t, kid, pub.KeyID) - return nil, database.ErrNotFound + return nil, acme.ErrNotFound }, }, next: func(w http.ResponseWriter, r *http.Request) { diff --git a/acme/db.go b/acme/db.go index dcc7846f..d678fef4 100644 --- a/acme/db.go +++ b/acme/db.go @@ -2,8 +2,16 @@ package acme import ( "context" + + "github.com/pkg/errors" ) +// ErrNotFound is an error that should be used by the acme.DB interface to +// indicate that an entity does not exist. For example, in the new-account +// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new +// account. +var ErrNotFound = errors.New("not found") + // DB is the DB interface expected by the step-ca ACME API. type DB interface { CreateAccount(ctx context.Context, acc *Account) error diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 3115e8ab..d7ac9655 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -30,7 +30,7 @@ func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, erro id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return "", acme.NewError(acme.ErrorMalformedType, "account with key-id %s not found", kid) + return "", acme.ErrNotFound } return "", errors.Wrapf(err, "error loading key-account index for key %s", kid) } @@ -42,7 +42,7 @@ func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { data, err := db.db.Get(accountTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return nil, acme.NewError(acme.ErrorMalformedType, "account %s not found", id) + return nil, acme.ErrNotFound } return nil, errors.Wrapf(err, "error loading account %s", id) } diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go index 9f889e64..5ba99a73 100644 --- a/acme/db/nosql/account_test.go +++ b/acme/db/nosql/account_test.go @@ -34,7 +34,7 @@ func TestDB_getDBAccount(t *testing.T) { return nil, nosqldb.ErrNotFound }, }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + err: acme.ErrNotFound, } }, "fail/db.Get-error": func(t *testing.T) test { @@ -142,7 +142,7 @@ func TestDB_getAccountIDByKeyID(t *testing.T) { return nil, nosqldb.ErrNotFound }, }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), + err: acme.ErrNotFound, } }, "fail/db.Get-error": func(t *testing.T) test { @@ -221,19 +221,6 @@ func TestDB_GetAccount(t *testing.T) { err: errors.New("error loading account accID: force"), } }, - "fail/forward-acme-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, string(key), accID) - - return nil, nosqldb.ErrNotFound - }, - }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), - } - }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -314,19 +301,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) { err: errors.New("error loading key-account index for key kid: force"), } }, - "fail/db.getAccountIDByKeyID-forward-acme-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, string(bucket), string(accountByKeyIDTable)) - assert.Equals(t, string(key), kid) - - return nil, nosqldb.ErrNotFound - }, - }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), - } - }, "fail/db.GetAccount-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ @@ -347,26 +321,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) { err: errors.New("error loading account accID: force"), } }, - "fail/db.GetAccount-forward-acme-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(accountByKeyIDTable): - assert.Equals(t, string(key), kid) - return []byte(accID), nil - case string(accountTable): - assert.Equals(t, string(key), accID) - return nil, nosqldb.ErrNotFound - default: - assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) - return nil, errors.New("force") - } - }, - }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), - } - }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) From a785131d094c2f1ba61e5ac49d7f559c5f1d4cd2 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 15:15:20 -0700 Subject: [PATCH 38/47] Fix lint issues --- acme/api/account.go | 1 - acme/api/handler.go | 4 - acme/api/handler_test.go | 10 -- acme/api/order_test.go | 2 +- acme/certificate_test.go | 240 ------------------------------------ acme/db/nosql/authz.go | 2 - acme/db/nosql/nonce_test.go | 7 +- acme/db/nosql/nosql_test.go | 16 +-- acme/db/nosql/order.go | 2 - acme/errors.go | 1 - acme/nonce_test.go | 154 ----------------------- 11 files changed, 12 insertions(+), 427 deletions(-) delete mode 100644 acme/certificate_test.go delete mode 100644 acme/nonce_test.go diff --git a/acme/api/account.go b/acme/api/account.go index c7f3d11a..ae39d2f7 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -205,5 +205,4 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { api.JSON(w, orders) logOrdersByAccount(w, orders) - return } diff --git a/acme/api/handler.go b/acme/api/handler.go index 16deeaf8..4fb1a18a 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -142,10 +142,6 @@ func (d *Directory) ToLog() (interface{}, error) { return string(b), nil } -type directory struct { - prefix, dns string -} - // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index eb8b9f56..1ff08600 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -425,16 +425,6 @@ func TestHandler_GetCertificate(t *testing.T) { } } -func ch() acme.Challenge { - return acme.Challenge{ - Type: "http-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chID", - ID: "chID", - } -} - func TestHandler_GetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 597ec018..506f0a0a 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -1570,7 +1570,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) ro := new(acme.Order) - err = json.Unmarshal(body, ro) + assert.FatalError(t, json.Unmarshal(body, ro)) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{url}) diff --git a/acme/certificate_test.go b/acme/certificate_test.go deleted file mode 100644 index adbf8e00..00000000 --- a/acme/certificate_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package acme - -/* -func defaultCertOps() (*CertOptions, error) { - crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt") - if err != nil { - return nil, err - } - inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt") - if err != nil { - return nil, err - } - root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt") - if err != nil { - return nil, err - } - return &CertOptions{ - AccountID: "accID", - OrderID: "ordID", - Leaf: crt, - Intermediates: []*x509.Certificate{inter, root}, - }, nil -} - -func newcert() (*Certificate, error) { - ops, err := defaultCertOps() - if err != nil { - return nil, err - } - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - } - return newCert(mockdb, *ops) -} - -func TestNewCert(t *testing.T) { - type test struct { - db nosql.DB - ops CertOptions - err *Error - id *string - } - tests := map[string]func(t *testing.T) test{ - "fail/cmpAndSwap-error": func(t *testing.T) test { - ops, err := defaultCertOps() - assert.FatalError(t, err) - return test{ - ops: *ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, old, nil) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error storing certificate: force")), - } - }, - "fail/cmpAndSwap-false": func(t *testing.T) test { - ops, err := defaultCertOps() - assert.FatalError(t, err) - return test{ - ops: *ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, old, nil) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")), - } - }, - "ok": func(t *testing.T) test { - ops, err := defaultCertOps() - assert.FatalError(t, err) - var _id string - id := &_id - return test{ - ops: *ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, old, nil) - *id = string(key) - return nil, true, nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if cert, err := newCert(tc.db, tc.ops); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, cert.ID, *tc.id) - assert.Equals(t, cert.AccountID, tc.ops.AccountID) - assert.Equals(t, cert.OrderID, tc.ops.OrderID) - - leaf := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: tc.ops.Leaf.Raw, - }) - var intermediates []byte - for _, cert := range tc.ops.Intermediates { - intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - })...) - } - assert.Equals(t, cert.Leaf, leaf) - assert.Equals(t, cert.Intermediates, intermediates) - - assert.True(t, cert.Created.Before(time.Now().Add(time.Minute))) - assert.True(t, cert.Created.After(time.Now().Add(-time.Minute))) - } - } - }) - } -} - -func TestGetCert(t *testing.T) { - type test struct { - id string - db nosql.DB - cert *certificate - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)), - } - }, - "fail/db-error": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading certificate: force")), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")), - } - }, - "ok": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - b, err := json.Marshal(cert) - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if cert, err := getCert(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.cert.ID, cert.ID) - assert.Equals(t, tc.cert.AccountID, cert.AccountID) - assert.Equals(t, tc.cert.OrderID, cert.OrderID) - assert.Equals(t, tc.cert.Created, cert.Created) - assert.Equals(t, tc.cert.Leaf, cert.Leaf) - assert.Equals(t, tc.cert.Intermediates, cert.Intermediates) - } - } - }) - } -} - -func TestCertificateToACME(t *testing.T) { - cert, err := newcert() - assert.FatalError(t, err) - acmeCert, err := cert.toACME(nil, nil) - assert.FatalError(t, err) - assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert) -} -*/ diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 449a9276..c3283603 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -10,8 +10,6 @@ import ( "github.com/smallstep/nosql" ) -var defaultExpiryDuration = time.Hour * 24 - // dbAuthz is the base authz type that others build from. type dbAuthz struct { ID string `json:"id"` diff --git a/acme/db/nosql/nonce_test.go b/acme/db/nosql/nonce_test.go index 1159ec00..131438f6 100644 --- a/acme/db/nosql/nonce_test.go +++ b/acme/db/nosql/nonce_test.go @@ -16,10 +16,9 @@ import ( func TestDB_CreateNonce(t *testing.T) { type test struct { - db nosql.DB - nonce *acme.Nonce - err error - _id *string + db nosql.DB + err error + _id *string } var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go index 7fd21c50..4396acc8 100644 --- a/acme/db/nosql/nosql_test.go +++ b/acme/db/nosql/nosql_test.go @@ -16,7 +16,7 @@ func TestNew(t *testing.T) { err error } var tests = map[string]test{ - "fail/db.CreateTable-error": test{ + "fail/db.CreateTable-error": { db: &db.MockNoSQLDB{ MCreateTable: func(bucket []byte) error { assert.Equals(t, string(bucket), string(accountTable)) @@ -25,7 +25,7 @@ func TestNew(t *testing.T) { }, err: errors.Errorf("error creating table %s: force", string(accountTable)), }, - "ok": test{ + "ok": { db: &db.MockNoSQLDB{ MCreateTable: func(bucket []byte) error { return nil @@ -60,16 +60,16 @@ func TestDB_save(t *testing.T) { err error } var tests = map[string]test{ - "fail/error-marshaling-new": test{ + "fail/error-marshaling-new": { nu: errorThrower("foo"), err: errors.New("error marshaling acme type: challenge"), }, - "fail/error-marshaling-old": test{ + "fail/error-marshaling-old": { nu: "new", old: errorThrower("foo"), err: errors.New("error marshaling acme type: challenge"), }, - "fail/db.CmpAndSwap-error": test{ + "fail/db.CmpAndSwap-error": { nu: "new", old: "old", db: &db.MockNoSQLDB{ @@ -83,7 +83,7 @@ func TestDB_save(t *testing.T) { }, err: errors.New("error saving acme challenge: force"), }, - "fail/db.CmpAndSwap-false-marshaling-old": test{ + "fail/db.CmpAndSwap-false-marshaling-old": { nu: "new", old: "old", db: &db.MockNoSQLDB{ @@ -97,7 +97,7 @@ func TestDB_save(t *testing.T) { }, err: errors.New("error saving acme challenge; changed since last read"), }, - "ok": test{ + "ok": { nu: "new", old: "old", db: &db.MockNoSQLDB{ @@ -110,7 +110,7 @@ func TestDB_save(t *testing.T) { }, }, }, - "ok/nils": test{ + "ok/nils": { nu: nil, old: nil, db: &db.MockNoSQLDB{ diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index c8fe53e1..8eb578a9 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -120,8 +120,6 @@ func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { return db.save(ctx, old.ID, nu, old, "order", orderTable) } -type orderIDsByAccount struct{} - func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) { ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() diff --git a/acme/errors.go b/acme/errors.go index 54182ec2..c4309599 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -128,7 +128,6 @@ type errorMetadata struct { var ( officialACMEPrefix = "urn:ietf:params:acme:error:" - stepACMEPrefix = "urn:step:acme:error:" errorServerInternalMetadata = errorMetadata{ typ: officialACMEPrefix + ErrorServerInternalType.String(), details: "The server experienced an internal error", diff --git a/acme/nonce_test.go b/acme/nonce_test.go deleted file mode 100644 index 2088d39b..00000000 --- a/acme/nonce_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package acme - -/* -func TestNewNonce(t *testing.T) { - type test struct { - db nosql.DB - err *Error - id *string - } - tests := map[string]func(t *testing.T) test{ - "fail/cmpAndSwap-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, nil) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error storing nonce: force")), - } - }, - "fail/cmpAndSwap-false": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, nil) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")), - } - }, - "ok": func(t *testing.T) test { - var _id string - id := &_id - return test{ - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, nil) - *id = string(key) - return nil, true, nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if n, err := newNonce(tc.db); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, n.ID, *tc.id) - - assert.True(t, n.Created.Before(time.Now().Add(time.Minute))) - assert.True(t, n.Created.After(time.Now().Add(-time.Minute))) - } - } - }) - } -} - -func TestUseNonce(t *testing.T) { - type test struct { - id string - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/update-not-found": func(t *testing.T) test { - id := "foo" - return test{ - db: &db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - assert.Equals(t, tx.Operations[0].Bucket, nonceTable) - assert.Equals(t, tx.Operations[0].Key, []byte(id)) - assert.Equals(t, tx.Operations[0].Cmd, database.Get) - - assert.Equals(t, tx.Operations[1].Bucket, nonceTable) - assert.Equals(t, tx.Operations[1].Key, []byte(id)) - assert.Equals(t, tx.Operations[1].Cmd, database.Delete) - return database.ErrNotFound - }, - }, - id: id, - err: BadNonceErr(nil), - } - }, - "fail/update-error": func(t *testing.T) test { - id := "foo" - return test{ - db: &db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - assert.Equals(t, tx.Operations[0].Bucket, nonceTable) - assert.Equals(t, tx.Operations[0].Key, []byte(id)) - assert.Equals(t, tx.Operations[0].Cmd, database.Get) - - assert.Equals(t, tx.Operations[1].Bucket, nonceTable) - assert.Equals(t, tx.Operations[1].Key, []byte(id)) - assert.Equals(t, tx.Operations[1].Cmd, database.Delete) - return errors.New("force") - }, - }, - id: id, - err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)), - } - }, - "ok": func(t *testing.T) test { - id := "foo" - return test{ - db: &db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - assert.Equals(t, tx.Operations[0].Bucket, nonceTable) - assert.Equals(t, tx.Operations[0].Key, []byte(id)) - assert.Equals(t, tx.Operations[0].Cmd, database.Get) - - assert.Equals(t, tx.Operations[1].Bucket, nonceTable) - assert.Equals(t, tx.Operations[1].Key, []byte(id)) - assert.Equals(t, tx.Operations[1].Cmd, database.Delete) - - return nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := useNonce(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } - }) - } -} -*/ From fd447c5b540ec0636d705f8d7ca0855e95eec7b8 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 16:45:26 -0700 Subject: [PATCH 39/47] Fix small nbf->naf bug in db.CreateOrder - still needs unit test --- acme/db/nosql/order.go | 4 +++- acme/errors.go | 10 ++++++++++ api/errors.go | 8 ++++++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 8eb578a9..305c43d1 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -3,6 +3,7 @@ package nosql import ( "context" "encoding/json" + "fmt" "sync" "time" @@ -91,9 +92,10 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { ExpiresAt: o.ExpiresAt, Identifiers: o.Identifiers, NotBefore: o.NotBefore, - NotAfter: o.NotBefore, + NotAfter: o.NotAfter, AuthorizationIDs: o.AuthorizationIDs, } + fmt.Printf("dbo = %+v\n", dbo) if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { return err } diff --git a/acme/errors.go b/acme/errors.go index c4309599..f4aa17e7 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -1,6 +1,7 @@ package acme import ( + "encoding/json" "fmt" "github.com/pkg/errors" @@ -337,3 +338,12 @@ func (e *Error) Cause() error { } return e.Err } + +// ToLog implements the EnableLogger interface. +func (e *Error) ToLog() (interface{}, error) { + b, err := json.Marshal(e) + if err != nil { + return nil, WrapErrorISE(err, "error marshaling acme.Error for logging") + } + return string(b), nil +} diff --git a/api/errors.go b/api/errors.go index 460192fc..67c9ba87 100644 --- a/api/errors.go +++ b/api/errors.go @@ -33,11 +33,15 @@ func WriteError(w http.ResponseWriter, err error) { // Write errors in the response writer if rl, ok := w.(logging.ResponseLogger); ok { + logErr := err + if u, ok := err.(*acme.Error); ok { + logErr = u.Err + } rl.WithFields(map[string]interface{}{ - "error": err, + "error": logErr, }) if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.(errs.StackTracer); ok { + if e, ok := logErr.(errs.StackTracer); ok { rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e), }) From bdace1e53f850965ab9b7a7afb7194cb2fbae404 Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 19:40:18 -0700 Subject: [PATCH 40/47] Add failure scenarios to db.CreateOrder unit tests --- acme/db/nosql/order.go | 2 -- acme/db/nosql/order_test.go | 14 ++++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 305c43d1..5513b0b6 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -3,7 +3,6 @@ package nosql import ( "context" "encoding/json" - "fmt" "sync" "time" @@ -95,7 +94,6 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { NotAfter: o.NotAfter, AuthorizationIDs: o.AuthorizationIDs, } - fmt.Printf("dbo = %+v\n", dbo) if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { return err } diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go index 746066a2..7248700f 100644 --- a/acme/db/nosql/order_test.go +++ b/acme/db/nosql/order_test.go @@ -385,6 +385,8 @@ func TestDB_UpdateOrder(t *testing.T) { func TestDB_CreateOrder(t *testing.T) { now := clock.Now() + nbf := now.Add(5 * time.Minute) + naf := now.Add(15 * time.Minute) type test struct { db nosql.DB o *acme.Order @@ -399,8 +401,8 @@ func TestDB_CreateOrder(t *testing.T) { CertificateID: "certID", Status: acme.StatusValid, ExpiresAt: now, - NotBefore: now, - NotAfter: now, + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ {Type: "dns", Value: "test.ca.smallstep.com"}, {Type: "dns", Value: "example.foo.com"}, @@ -443,8 +445,8 @@ func TestDB_CreateOrder(t *testing.T) { CertificateID: "certID", Status: acme.StatusValid, ExpiresAt: now, - NotBefore: now, - NotAfter: now, + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ {Type: "dns", Value: "test.ca.smallstep.com"}, {Type: "dns", Value: "example.foo.com"}, @@ -496,8 +498,8 @@ func TestDB_CreateOrder(t *testing.T) { ProvisionerID: "provID", Status: acme.StatusValid, ExpiresAt: now, - NotBefore: now, - NotAfter: now, + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ {Type: "dns", Value: "test.ca.smallstep.com"}, {Type: "dns", Value: "example.foo.com"}, From 6b8585c702a7578725718d599f4a30bf17e942f0 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 29 Mar 2021 12:04:14 -0700 Subject: [PATCH 41/47] PR review fixes / updates --- acme/account.go | 4 +- acme/api/handler.go | 11 +- acme/api/handler_test.go | 15 +- acme/authority_test.go | 1721 ----------------------------------- acme/authorization.go | 8 +- acme/challenge.go | 19 +- acme/common.go | 4 +- acme/db/nosql/account.go | 4 +- acme/db/nosql/authz.go | 4 +- acme/db/nosql/nonce.go | 41 +- acme/db/nosql/nonce_test.go | 134 +-- acme/db/nosql/order.go | 10 +- acme/order.go | 4 +- 13 files changed, 113 insertions(+), 1866 deletions(-) delete mode 100644 acme/authority_test.go diff --git a/acme/account.go b/acme/account.go index 3b6bafed..197a3400 100644 --- a/acme/account.go +++ b/acme/account.go @@ -11,11 +11,11 @@ import ( // Account is a subset of the internal account type containing only those // attributes required for responses in the ACME protocol. type Account struct { + ID string `json:"-"` + Key *jose.JSONWebKey `json:"-"` Contact []string `json:"contact,omitempty"` Status Status `json:"status"` OrdersURL string `json:"orders"` - ID string `json:"-"` - Key *jose.JSONWebKey `json:"-"` } // ToLog enables response logging. diff --git a/acme/api/handler.go b/acme/api/handler.go index 4fb1a18a..17565998 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -21,14 +21,14 @@ func link(url, typ string) string { } // Clock that returns time in UTC rounded to seconds. -type Clock int +type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { return time.Now().UTC().Round(time.Second) } -var clock = new(Clock) +var clock Clock type payloadInfo struct { value []byte @@ -65,7 +65,7 @@ type HandlerOptions struct { // NewHandler returns a new ACME API handler. func NewHandler(ops HandlerOptions) api.RouterHandler { client := http.Client{ - Timeout: time.Duration(30 * time.Second), + Timeout: 30 * time.Second, } dialer := &net.Dialer{ Timeout: 30 * time.Second, @@ -89,8 +89,8 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { func (h *Handler) Route(r api.Router) { getLink := h.linker.GetLinkExplicit // Standard ACME API - r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) + r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) @@ -218,6 +218,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) return } + ch.AuthorizationID = azID if acc.ID != ch.AccountID { api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 1ff08600..5501479d 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -582,6 +582,7 @@ func TestHandler_GetChallenge(t *testing.T) { assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Type, "http-01") assert.Equals(t, ch.AccountID, "accID") + assert.Equals(t, ch.AuthorizationID, "authzID") assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) return acme.NewErrorISE("force") }, @@ -623,17 +624,19 @@ func TestHandler_GetChallenge(t *testing.T) { assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Type, "http-01") assert.Equals(t, ch.AccountID, "accID") + assert.Equals(t, ch.AuthorizationID, "authzID") assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) return nil }, }, ch: &acme.Challenge{ - ID: "chID", - Status: acme.StatusPending, - Type: "http-01", - AccountID: "accID", - URL: url, - Error: acme.NewError(acme.ErrorConnectionType, "force"), + ID: "chID", + Status: acme.StatusPending, + AuthorizationID: "authzID", + Type: "http-01", + AccountID: "accID", + URL: url, + Error: acme.NewError(acme.ErrorConnectionType, "force"), }, vco: &acme.ValidateChallengeOptions{ HTTPGet: func(string) (*http.Response, error) { diff --git a/acme/authority_test.go b/acme/authority_test.go deleted file mode 100644 index 0e8de984..00000000 --- a/acme/authority_test.go +++ /dev/null @@ -1,1721 +0,0 @@ -package acme - -/* -func TestAuthorityGetLink(t *testing.T) { - auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - type test struct { - auth *Authority - typ Link - abs bool - inputs []string - res string - } - tests := map[string]func(t *testing.T) test{ - "ok/new-account/abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: NewAccountLink, - abs: true, - res: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - } - }, - "ok/new-account/no-abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: NewAccountLink, - abs: false, - res: fmt.Sprintf("/%s/new-account", provName), - } - }, - "ok/order/abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: OrderLink, - abs: true, - inputs: []string{"foo"}, - res: fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName), - } - }, - "ok/order/no-abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: OrderLink, - abs: false, - inputs: []string{"foo"}, - res: fmt.Sprintf("/%s/order/foo", provName), - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - link := tc.auth.GetLink(ctx, tc.typ, tc.abs, tc.inputs...) - assert.Equals(t, tc.res, link) - }) - } -} - -func TestAuthorityGetDirectory(t *testing.T) { - auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - - prov := newProv() - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - ctx context.Context - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/empty-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - } - }, - "ok/no-baseURL": func(t *testing.T) test { - return test{ - ctx: context.WithValue(context.Background(), ProvisionerContextKey, prov), - } - }, - "ok/baseURL": func(t *testing.T) test { - return test{ - ctx: ctx, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if dir, err := auth.GetDirectory(tc.ctx); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - bu := BaseURLFromContext(tc.ctx) - if bu == nil { - bu = &url.URL{Scheme: "https", Host: "ca.smallstep.com"} - } - - var provName string - prov, err := ProvisionerFromContext(tc.ctx) - if err != nil { - provName = "" - } else { - provName = url.PathEscape(prov.GetName()) - } - - assert.Equals(t, dir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", bu.String(), provName)) - assert.Equals(t, dir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", bu.String(), provName)) - assert.Equals(t, dir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", bu.String(), provName)) - assert.Equals(t, dir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", bu.String(), provName)) - assert.Equals(t, dir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", bu.String(), provName)) - } - } - }) - } -} - -func TestAuthorityNewNonce(t *testing.T) { - type test struct { - auth *Authority - res *string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/newNonce-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - res: nil, - err: ServerInternalErr(errors.New("error storing nonce: force")), - } - }, - "ok": func(t *testing.T) test { - var _res string - res := &_res - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - *res = string(key) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - res: res, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if nonce, err := tc.auth.NewNonce(); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, nonce, *tc.res) - } - } - }) - } -} - -func TestAuthorityUseNonce(t *testing.T) { - type test struct { - auth *Authority - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/newNonce-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - return errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - err: ServerInternalErr(errors.New("error deleting nonce foo: force")), - } - }, - "ok": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - return nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.auth.UseNonce("foo"); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAuthorityNewAccount(t *testing.T) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - ops := AccountOptions{ - Key: jwk, Contact: []string{"foo", "bar"}, - } - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - ops AccountOptions - err *Error - acc **Account - } - tests := map[string]func(t *testing.T) test{ - "fail/newAccount-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: ops, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "ok": func(t *testing.T) test { - var ( - _acmeacc = &Account{} - acmeacc = &_acmeacc - count = 0 - dir = newDirectory("ca.smallstep.com", "acme") - ) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - var acc *account - assert.FatalError(t, json.Unmarshal(newval, &acc)) - *acmeacc, err = acc.toACME(ctx, nil, dir) - return nil, true, nil - } - count++ - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: ops, - acc: acmeacc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.NewAccount(ctx, tc.ops); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - expb, err := json.Marshal(*tc.acc) - assert.FatalError(t, err) - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetAccount(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id string - err *Error - acc *account - } - tests := map[string]func(t *testing.T) test{ - "fail/getAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - acc: acc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.GetAccount(ctx, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetAccountByKey(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - jwk *jose.JSONWebKey - err *Error - acc *account - } - tests := map[string]func(t *testing.T) test{ - "fail/generate-thumbprint-error": func(t *testing.T) test { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - jwk.Key = "foo" - auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - jwk: jwk, - err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "fail/getAccount-error": func(t *testing.T) test { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - kid, err := keyToID(jwk) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - jwk: jwk, - err: ServerInternalErr(errors.New("error loading key-account index: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch { - case count == 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - ret = []byte(acc.ID) - case count == 1: - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - ret = b - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - jwk: acc.Key, - acc: acc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.GetAccountByKey(ctx, tc.jwk); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetOrder(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id, accID string - err *Error - o *order - } - tests := map[string]func(t *testing.T) test{ - "fail/getOrder-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading order foo: force")), - } - }, - "fail/order-not-owned-by-account": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own order")), - } - }, - "fail/updateStatus-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - i := 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch { - case i == 0: - i++ - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - default: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(o.Authorizations[0])) - return nil, ServerInternalErr(errors.New("force")) - } - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", o.Authorizations[0])), - } - }, - "ok": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = "valid" - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - o: o, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeO, err := tc.auth.GetOrder(ctx, tc.accID, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeO) - assert.FatalError(t, err) - - acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetCertificate(t *testing.T) { - type test struct { - auth *Authority - id, accID string - err *Error - cert *certificate - } - tests := map[string]func(t *testing.T) test{ - "fail/getCertificate-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading certificate: force")), - } - }, - "fail/certificate-not-owned-by-account": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - b, err := json.Marshal(cert) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: cert.ID, - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own certificate")), - } - }, - "ok": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - b, err := json.Marshal(cert) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: cert.ID, - accID: cert.AccountID, - cert: cert, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeCert, err := tc.auth.GetCertificate(tc.accID, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeCert) - assert.FatalError(t, err) - - acmeExp, err := tc.cert.toACME(nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetAuthz(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id, accID string - err *Error - acmeAz *Authz - } - tests := map[string]func(t *testing.T) test{ - "fail/getAuthz-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", id)), - } - }, - "fail/authz-not-owned-by-account": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: az.getID(), - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own authz")), - } - }, - "fail/update-status-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - count := 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - ret = b - case 1: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(az.getChallenges()[0])) - return nil, errors.New("force") - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: az.getID(), - accID: az.getAccountID(), - err: ServerInternalErr(errors.New("error updating authz status: error loading challenge")), - } - }, - "ok": func(t *testing.T) test { - var ch1B, ch2B, ch3B = &[]byte{}, &[]byte{}, &[]byte{} - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - *ch1B = newval - case 1: - *ch2B = newval - case 2: - *ch3B = newval - } - count++ - return nil, true, nil - }, - } - az, err := newAuthz(mockdb, "1234", Identifier{ - Type: "dns", Value: "acme.example.com", - }) - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusValid - b, err := json.Marshal(az) - assert.FatalError(t, err) - - ch1, err := unmarshalChallenge(*ch1B) - assert.FatalError(t, err) - ch2, err := unmarshalChallenge(*ch2B) - assert.FatalError(t, err) - ch3, err := unmarshalChallenge(*ch3B) - assert.FatalError(t, err) - count = 0 - mockdb = &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch1.getID())) - ret = *ch1B - case 1: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch2.getID())) - ret = *ch2B - case 2: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch3.getID())) - ret = *ch3B - } - count++ - return ret, nil - }, - } - acmeAz, err := az.toACME(ctx, mockdb, newDirectory("ca.smallstep.com", "acme")) - assert.FatalError(t, err) - - count = 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - ret = b - case 1: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch1.getID())) - ret = *ch1B - case 2: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch2.getID())) - ret = *ch2B - case 3: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch3.getID())) - ret = *ch3B - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: az.getID(), - accID: az.getAccountID(), - acmeAz: acmeAz, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAz, err := tc.auth.GetAuthz(ctx, tc.accID, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAz) - assert.FatalError(t, err) - - expb, err := json.Marshal(tc.acmeAz) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityNewOrder(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - ops OrderOptions - ctx context.Context - err *Error - o **Order - } - tests := map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: defaultOrderOps(), - ctx: context.Background(), - err: ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/newOrder-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: defaultOrderOps(), - ctx: ctx, - err: ServerInternalErr(errors.New("error creating order: error creating http challenge: error saving acme challenge: force")), - } - }, - "ok": func(t *testing.T) test { - var ( - _acmeO = &Order{} - acmeO = &_acmeO - count = 0 - dir = newDirectory("ca.smallstep.com", "acme") - err error - _accID string - accID = &_accID - ) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - assert.Equals(t, bucket, challengeTable) - case 1: - assert.Equals(t, bucket, challengeTable) - case 2: - assert.Equals(t, bucket, challengeTable) - case 3: - assert.Equals(t, bucket, authzTable) - case 4: - assert.Equals(t, bucket, challengeTable) - case 5: - assert.Equals(t, bucket, challengeTable) - case 6: - assert.Equals(t, bucket, challengeTable) - case 7: - assert.Equals(t, bucket, authzTable) - case 8: - assert.Equals(t, bucket, orderTable) - var o order - assert.FatalError(t, json.Unmarshal(newval, &o)) - *acmeO, err = o.toACME(ctx, nil, dir) - assert.FatalError(t, err) - *accID = o.AccountID - case 9: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, string(key), *accID) - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: defaultOrderOps(), - ctx: ctx, - o: acmeO, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeO, err := tc.auth.NewOrder(tc.ctx, tc.ops); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeO) - assert.FatalError(t, err) - expb, err := json.Marshal(*tc.o) - assert.FatalError(t, err) - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetOrdersByAccount(t *testing.T) { - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - type test struct { - auth *Authority - id string - err *Error - res []string - } - tests := map[string]func(t *testing.T) test{ - "fail/getOrderIDsByAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), - } - }, - "fail/getOrder-error": func(t *testing.T) test { - var ( - id = "zap" - oids = []string{"foo", "bar"} - count = 0 - err error - ) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(id)) - ret, err = json.Marshal(oids) - assert.FatalError(t, err) - case 1: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(oids[0])) - return nil, errors.New("force") - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading order foo for account zap: error loading order foo: force")), - } - }, - "ok": func(t *testing.T) test { - accID := "zap" - - foo, err := newO() - assert.FatalError(t, err) - bfoo, err := json.Marshal(foo) - assert.FatalError(t, err) - - bar, err := newO() - assert.FatalError(t, err) - bar.Status = StatusInvalid - bbar, err := json.Marshal(bar) - assert.FatalError(t, err) - - zap, err := newO() - assert.FatalError(t, err) - bzap, err := json.Marshal(zap) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(orderTable): - dbGetOrder++ - switch dbGetOrder { - case 1: - return bfoo, nil - case 2: - return bbar, nil - case 3: - return bzap, nil - } - case string(ordersByAccountIDTable): - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - ret, err := json.Marshal([]string{foo.ID, bar.ID, zap.ID}) - assert.FatalError(t, err) - return ret, nil - case string(challengeTable): - return bch, nil - case string(authzTable): - return baz, nil - } - return nil, errors.Errorf("should not be query db table %s", bucket) - }, - MCmpAndSwap: func(bucket, key, old, newVal []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, string(key), accID) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: accID, - res: []string{ - fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID), - fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, zap.ID), - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if orderLinks, err := tc.auth.GetOrdersByAccount(ctx, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res, orderLinks) - } - } - }) - } -} - -func TestAuthorityFinalizeOrder(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id, accID string - ctx context.Context - err *Error - o *order - } - tests := map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: "foo", - ctx: context.Background(), - err: ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/getOrder-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - ctx: ctx, - err: ServerInternalErr(errors.New("error loading order foo: force")), - } - }, - "fail/order-not-owned-by-account": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: "foo", - ctx: ctx, - err: UnauthorizedErr(errors.New("account does not own order")), - } - }, - "fail/finalize-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Expires = time.Now().Add(-time.Minute) - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - ctx: ctx, - err: ServerInternalErr(errors.New("error finalizing order: error storing order: force")), - } - }, - "ok": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - o.Certificate = "certID" - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - ctx: ctx, - o: o, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeO, err := tc.auth.FinalizeOrder(tc.ctx, tc.accID, tc.id, nil); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeO) - assert.FatalError(t, err) - - acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityValidateChallenge(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - - type test struct { - auth *Authority - id, accID string - err *Error - ch challenge - jwk *jose.JSONWebKey - server *httptest.Server - } - tests := map[string]func(t *testing.T) test{ - "fail/getChallenge-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", id)), - } - }, - "fail/challenge-not-owned-by-account": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own challenge")), - } - }, - "fail/validate-error": func(t *testing.T) test { - keyauth := "temp" - keyauthp := &keyauth - // Create test server that returns challenge auth - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%s\r\n", *keyauthp) - })) - t.Cleanup(func() { ts.Close() }) - - ch, err := newHTTPChWithServer(strings.TrimPrefix(ts.URL, "http://")) - assert.FatalError(t, err) - - jwk, _, err := jose.GenerateDefaultKeyPair([]byte("pass")) - assert.FatalError(t, err) - - thumbprint, err := jwk.Thumbprint(crypto.SHA256) - assert.FatalError(t, err) - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - *keyauthp = fmt.Sprintf("%s.%s", ch.getToken(), encPrint) - - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: ch.getAccountID(), - jwk: jwk, - server: ts, - err: ServerInternalErr(errors.New("error attempting challenge validation: error saving acme challenge: force")), - } - }, - "ok/already-valid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - _ch.baseChallenge.Validated = clock.Now() - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: ch.getAccountID(), - ch: ch, - } - }, - "ok": func(t *testing.T) test { - keyauth := "temp" - keyauthp := &keyauth - // Create test server that returns challenge auth - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%s\r\n", *keyauthp) - })) - t.Cleanup(func() { ts.Close() }) - - ch, err := newHTTPChWithServer(strings.TrimPrefix(ts.URL, "http://")) - assert.FatalError(t, err) - - jwk, _, err := jose.GenerateDefaultKeyPair([]byte("pass")) - assert.FatalError(t, err) - - thumbprint, err := jwk.Thumbprint(crypto.SHA256) - assert.FatalError(t, err) - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - *keyauthp = fmt.Sprintf("%s.%s", ch.getToken(), encPrint) - - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: ch.getAccountID(), - jwk: jwk, - server: ts, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeCh, err := tc.auth.ValidateChallenge(ctx, tc.accID, tc.id, tc.jwk); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeCh) - assert.FatalError(t, err) - - if tc.ch != nil { - acmeExp, err := tc.ch.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - } - }) - } -} - -func TestAuthorityUpdateAccount(t *testing.T) { - contact := []string{"baz", "zap"} - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id string - contact []string - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - contact: contact, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), - } - }, - "fail/update-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - contact: contact, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Contact = contact - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - contact: contact, - acc: clone, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.UpdateAccount(ctx, tc.id, tc.contact); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityDeactivateAccount(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id string - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), - } - }, - "fail/deactivate-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Status = StatusDeactivated - clone.Deactivated = clock.Now() - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - acc: clone, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.DeactivateAccount(ctx, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} -*/ diff --git a/acme/authorization.go b/acme/authorization.go index 4d5c42c8..d2df5ea5 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -8,15 +8,15 @@ import ( // Authorization representst an ACME Authorization. type Authorization struct { + ID string `json:"-"` + AccountID string `json:"-"` + Token string `json:"-"` Identifier Identifier `json:"identifier"` Status Status `json:"status"` - ExpiresAt time.Time `json:"expires"` Challenges []*Challenge `json:"challenges"` Wildcard bool `json:"wildcard"` + ExpiresAt time.Time `json:"expires"` Error *Error `json:"error,omitempty"` - ID string `json:"-"` - AccountID string `json:"-"` - Token string `json:"-"` } // ToLog enables response logging. diff --git a/acme/challenge.go b/acme/challenge.go index b4f151cd..94c74b74 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -22,15 +22,16 @@ import ( // Challenge represents an ACME response Challenge type. type Challenge struct { - Type string `json:"type"` - Status Status `json:"status"` - Token string `json:"token"` - ValidatedAt string `json:"validated,omitempty"` - URL string `json:"url"` - Error *Error `json:"error,omitempty"` - ID string `json:"-"` - AccountID string `json:"-"` - Value string `json:"-"` + ID string `json:"-"` + AccountID string `json:"-"` + AuthorizationID string `json:"-"` + Value string `json:"-"` + Type string `json:"type"` + Status Status `json:"status"` + Token string `json:"token"` + ValidatedAt string `json:"validated,omitempty"` + URL string `json:"url"` + Error *Error `json:"error,omitempty"` } // ToLog enables response logging. diff --git a/acme/common.go b/acme/common.go index f7fd7141..05c909eb 100644 --- a/acme/common.go +++ b/acme/common.go @@ -15,14 +15,14 @@ type CertificateAuthority interface { } // Clock that returns time in UTC rounded to seconds. -type Clock int +type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { return time.Now().UTC().Round(time.Second) } -var clock = new(Clock) +var clock Clock // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index d7ac9655..1c3bec5d 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -14,11 +14,11 @@ import ( // dbAccount represents an ACME account. type dbAccount struct { ID string `json:"id"` - CreatedAt time.Time `json:"createdAt"` - DeactivatedAt time.Time `json:"deactivatedAt"` Key *jose.JSONWebKey `json:"key"` Contact []string `json:"contact,omitempty"` Status acme.Status `json:"status"` + CreatedAt time.Time `json:"createdAt"` + DeactivatedAt time.Time `json:"deactivatedAt"` } func (dba *dbAccount) clone() *dbAccount { diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index c3283603..6decbe4f 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -16,12 +16,12 @@ type dbAuthz struct { AccountID string `json:"accountID"` Identifier acme.Identifier `json:"identifier"` Status acme.Status `json:"status"` - ExpiresAt time.Time `json:"expiresAt"` + Token string `json:"token"` ChallengeIDs []string `json:"challengeIDs"` Wildcard bool `json:"wildcard"` CreatedAt time.Time `json:"createdAt"` + ExpiresAt time.Time `json:"expiresAt"` Error *acme.Error `json:"error"` - Token string `json:"token"` } func (ba *dbAuthz) clone() *dbAuthz { diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index 4f587ae0..6344f0c6 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -3,12 +3,12 @@ package nosql import ( "context" "encoding/base64" - "encoding/json" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" ) // dbNonce contains nonce metadata used in the ACME protocol. @@ -45,24 +45,27 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { // DeleteNonce verifies that the nonce is valid (by checking if it exists), // and if so, consumes the nonce resource by deleting it from the database. func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error { - id := string(nonce) - b, err := db.db.Get(nonceTable, []byte(nonce)) - if nosql.IsErrNotFound(err) { - return errors.Wrapf(err, "nonce %s not found", id) - } else if err != nil { - return errors.Wrapf(err, "error loading nonce %s", id) - } + err := db.db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Get, + }, + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Delete, + }, + }, + }) - dbn := new(dbNonce) - if err := json.Unmarshal(b, dbn); err != nil { - return errors.Wrapf(err, "error unmarshaling nonce %s", string(nonce)) - } - if !dbn.DeletedAt.IsZero() { - return acme.NewError(acme.ErrorBadNonceType, "nonce %s already deleted", id) + switch { + case nosql.IsErrNotFound(err): + return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce)) + case err != nil: + return errors.Wrapf(err, "error deleting nonce %s", string(nonce)) + default: + return nil } - - nu := dbn.clone() - nu.DeletedAt = clock.Now() - - return db.save(ctx, id, nu, dbn, "nonce", nonceTable) } diff --git a/acme/db/nosql/nonce_test.go b/acme/db/nosql/nonce_test.go index 131438f6..05d73d52 100644 --- a/acme/db/nosql/nonce_test.go +++ b/acme/db/nosql/nonce_test.go @@ -11,7 +11,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" - nosqldb "github.com/smallstep/nosql/database" + "github.com/smallstep/nosql/database" ) func TestDB_CreateNonce(t *testing.T) { @@ -85,108 +85,57 @@ func TestDB_DeleteNonce(t *testing.T) { nonceID := "nonceID" type test struct { - db nosql.DB - err error + db nosql.DB + err error + acmeErr *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, string(key), nonceID) - - return nil, nosqldb.ErrNotFound - }, - }, - err: errors.New("nonce nonceID not found"), - } - }, - "fail/db.Get-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, string(key), nonceID) - - return nil, errors.Errorf("force") - }, - }, - err: errors.New("error loading nonce nonceID: force"), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, string(key), nonceID) - - a := []string{"foo", "bar", "baz"} - b, err := json.Marshal(a) - assert.FatalError(t, err) - - return b, nil + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return database.ErrNotFound }, }, - err: errors.New("error unmarshaling nonce nonceID"), + acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID), } }, - "fail/already-used": func(t *testing.T) test { + "fail/db.Update-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, string(key), nonceID) - - nonce := dbNonce{ - ID: nonceID, - CreatedAt: clock.Now().Add(-5 * time.Minute), - DeletedAt: clock.Now(), - } - b, err := json.Marshal(nonce) - assert.FatalError(t, err) - - return b, nil + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return errors.New("force") }, }, - err: acme.NewError(acme.ErrorBadNonceType, "nonce already deleted"), + err: errors.New("error deleting nonce nonceID: force"), } }, "ok": func(t *testing.T) test { - nonce := dbNonce{ - ID: nonceID, - CreatedAt: clock.Now().Add(-5 * time.Minute), - } - b, err := json.Marshal(nonce) - assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, string(key), nonceID) - - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, b) - - dbo := new(dbNonce) - assert.FatalError(t, json.Unmarshal(old, dbo)) - assert.Equals(t, dbo.ID, string(key)) - assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbo.CreatedAt)) - assert.True(t, clock.Now().Add(-4*time.Minute).After(dbo.CreatedAt)) - assert.True(t, dbo.DeletedAt.IsZero()) - - dbn := new(dbNonce) - assert.FatalError(t, json.Unmarshal(nu, dbn)) - assert.Equals(t, dbn.ID, string(key)) - assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbn.CreatedAt)) - assert.True(t, clock.Now().Add(-4*time.Minute).After(dbn.CreatedAt)) - assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.DeletedAt)) - assert.True(t, clock.Now().Add(time.Minute).After(dbn.DeletedAt)) - return nil, true, nil + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return nil }, }, } @@ -197,8 +146,19 @@ func TestDB_DeleteNonce(t *testing.T) { t.Run(name, func(t *testing.T) { db := DB{db: tc.db} if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { - if assert.NotNil(t, tc.err) { - assert.HasPrefix(t, err.Error(), tc.err.Error()) + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } } } else { assert.Nil(t, tc.err) diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 5513b0b6..ba3934af 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -18,15 +18,15 @@ type dbOrder struct { ID string `json:"id"` AccountID string `json:"accountID"` ProvisionerID string `json:"provisionerID"` - CreatedAt time.Time `json:"createdAt"` - ExpiresAt time.Time `json:"expiresAt,omitempty"` - Status acme.Status `json:"status"` Identifiers []acme.Identifier `json:"identifiers"` + AuthorizationIDs []string `json:"authorizationIDs"` + Status acme.Status `json:"status"` NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` - Error *acme.Error `json:"error,omitempty"` - AuthorizationIDs []string `json:"authorizationIDs"` + CreatedAt time.Time `json:"createdAt"` + ExpiresAt time.Time `json:"expiresAt,omitempty"` CertificateID string `json:"certificate,omitempty"` + Error *acme.Error `json:"error,omitempty"` } func (a *dbOrder) clone() *dbOrder { diff --git a/acme/order.go b/acme/order.go index a6112362..7405906d 100644 --- a/acme/order.go +++ b/acme/order.go @@ -21,6 +21,8 @@ type Identifier struct { // Order contains order metadata for the ACME protocol order type. type Order struct { ID string `json:"id"` + AccountID string `json:"-"` + ProvisionerID string `json:"-"` Status Status `json:"status"` ExpiresAt time.Time `json:"expires,omitempty"` Identifiers []Identifier `json:"identifiers"` @@ -32,8 +34,6 @@ type Order struct { FinalizeURL string `json:"finalize"` CertificateID string `json:"-"` CertificateURL string `json:"certificate,omitempty"` - AccountID string `json:"-"` - ProvisionerID string `json:"-"` } // ToLog enables response logging. From 440678cb62e4c0b823bf23b47dea371aa5139922 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 29 Mar 2021 22:58:26 -0700 Subject: [PATCH 42/47] Add markInvalid arg to storeError for invalidating challenge --- acme/api/handler.go | 2 +- acme/challenge.go | 50 +++++---- acme/challenge_test.go | 239 +++++++++++++++++++++++++++-------------- acme/common.go | 2 +- acme/order.go | 6 +- 5 files changed, 190 insertions(+), 109 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index 17565998..e557f33b 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -25,7 +25,7 @@ type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { - return time.Now().UTC().Round(time.Second) + return time.Now().UTC().Truncate(time.Second) } var clock Clock diff --git a/acme/challenge.go b/acme/challenge.go index 94c74b74..a47fc7df 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "strings" "time" @@ -49,7 +50,7 @@ func (ch *Challenge) ToLog() (interface{}, error) { // updated. func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { // If already valid or invalid then return without performing validation. - if ch.Status == StatusValid || ch.Status == StatusInvalid { + if ch.Status != StatusPending { return nil } switch ch.Type { @@ -65,32 +66,32 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, } func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { - url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token) + url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - resp, err := vo.HTTPGet(url) + resp, err := vo.HTTPGet(url.String()) if err != nil { - return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err, + return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", url)) } + defer resp.Body.Close() if resp.StatusCode >= 400 { - return storeError(ctx, ch, db, NewError(ErrorConnectionType, + return storeError(ctx, db, ch, false, NewError(ErrorConnectionType, "error doing http GET for url %s with status code %d", url, resp.StatusCode)) } - defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { return WrapErrorISE(err, "error reading "+ "response body for url %s", url) } - keyAuth := strings.Trim(string(body), "\r\n") + keyAuth := strings.TrimSpace(string(body)) expected, err := KeyAuthorization(ch.Token, jwk) if err != nil { return err } if keyAuth != expected { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expected, keyAuth)) } @@ -107,7 +108,11 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { config := &tls.Config{ - NextProtos: []string{"acme-tls/1"}, + NextProtos: []string{"acme-tls/1"}, + // https://tools.ietf.org/html/rfc8737#section-4 + // ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2 + // [RFC5246] or higher when connecting to clients for validation. + MinVersion: tls.VersionTLS12, ServerName: ch.Value, InsecureSkipVerify: true, // we expect a self-signed challenge certificate } @@ -116,7 +121,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON conn, err := vo.TLSDial("tcp", hostPort, config) if err != nil { - return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err, + return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing TLS dial for %s", hostPort)) } defer conn.Close() @@ -125,19 +130,19 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON certs := cs.PeerCertificates if len(certs) == 0 { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "%s challenge for %s resulted in no certificates", ch.Type, ch.Value)) } if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) } leafCert := certs[0] if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)) } @@ -154,7 +159,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON for _, ext := range leafCert.Extensions { if idPeAcmeIdentifier.Equal(ext.Id) { if !ext.Critical { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) } @@ -162,12 +167,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON rest, err := asn1.Unmarshal(ext.Value, &extValue) if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) } if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ "expected acmeValidationV1 extension value %s for this challenge but got %s", hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue))) @@ -189,11 +194,11 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON } if foundIDPeAcmeIdentifierV1Obsolete { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) } - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } @@ -206,7 +211,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) if err != nil { - return storeError(ctx, ch, db, WrapError(ErrorDNSType, err, + return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) } @@ -224,7 +229,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK } } if !found { - return storeError(ctx, ch, db, NewError(ErrorRejectedIdentifierType, + return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords)) } @@ -251,8 +256,11 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { } // storeError the given error to an ACME error and saves using the DB interface. -func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { +func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err *Error) error { ch.Error = err + if markInvalid { + ch.Status = StatusInvalid + } if err := db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "failure saving error to acme challenge") } diff --git a/acme/challenge_test.go b/acme/challenge_test.go index caaca8f6..14287945 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -31,17 +31,19 @@ import ( func Test_storeError(t *testing.T) { type test struct { - ch *Challenge - db DB - err *Error + ch *Challenge + db DB + markInvalid bool + err *Error } err := NewError(ErrorMalformedType, "foo") tests := map[string]func(t *testing.T) test{ "fail/db.UpdateChallenge-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, } return test{ ch: ch, @@ -50,6 +52,7 @@ func Test_storeError(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -64,9 +67,10 @@ func Test_storeError(t *testing.T) { }, "fail/db.UpdateChallenge-acme-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, } return test{ ch: ch, @@ -75,6 +79,7 @@ func Test_storeError(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -89,9 +94,10 @@ func Test_storeError(t *testing.T) { }, "ok": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, } return test{ ch: ch, @@ -100,6 +106,7 @@ func Test_storeError(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -111,11 +118,38 @@ func Test_storeError(t *testing.T) { }, } }, + "ok/mark-invalid": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, + } + return test{ + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusInvalid) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + markInvalid: true, + } + }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := storeError(context.Background(), tc.ch, tc.db, err); err != nil { + if err := storeError(context.Background(), tc.db, tc.ch, tc.markInvalid, err); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -499,9 +533,10 @@ func TestHTTP01Validate(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/http-get-error-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } return test{ @@ -515,6 +550,8 @@ func TestHTTP01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) @@ -530,9 +567,10 @@ func TestHTTP01Validate(t *testing.T) { }, "ok/http-get-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } return test{ @@ -546,6 +584,8 @@ func TestHTTP01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) @@ -560,9 +600,10 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/http-get->=400-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } return test{ @@ -571,6 +612,7 @@ func TestHTTP01Validate(t *testing.T) { HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, + Body: errReader(0), }, nil }, }, @@ -578,6 +620,8 @@ func TestHTTP01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) @@ -593,9 +637,10 @@ func TestHTTP01Validate(t *testing.T) { }, "ok/http-get->=400": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } return test{ @@ -604,6 +649,7 @@ func TestHTTP01Validate(t *testing.T) { HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, + Body: errReader(0), }, nil }, }, @@ -611,6 +657,8 @@ func TestHTTP01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) @@ -625,9 +673,10 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/read-body": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } return test{ @@ -644,9 +693,10 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -667,9 +717,10 @@ func TestHTTP01Validate(t *testing.T) { }, "ok/key-auth-mismatch": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -692,6 +743,7 @@ func TestHTTP01Validate(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusInvalid) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) @@ -707,9 +759,10 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -732,6 +785,7 @@ func TestHTTP01Validate(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusInvalid) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) @@ -748,9 +802,10 @@ func TestHTTP01Validate(t *testing.T) { }, "fail/update-challenge-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -789,9 +844,10 @@ func TestHTTP01Validate(t *testing.T) { }, "ok": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: "zap.internal", + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -864,9 +920,10 @@ func TestDNS01Validate(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/lookupTXT-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } return test{ @@ -880,6 +937,8 @@ func TestDNS01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) @@ -896,9 +955,10 @@ func TestDNS01Validate(t *testing.T) { }, "ok/lookupTXT-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } return test{ @@ -912,6 +972,8 @@ func TestDNS01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) @@ -927,9 +989,10 @@ func TestDNS01Validate(t *testing.T) { }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -949,9 +1012,10 @@ func TestDNS01Validate(t *testing.T) { }, "fail/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -971,6 +1035,8 @@ func TestDNS01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) @@ -988,9 +1054,10 @@ func TestDNS01Validate(t *testing.T) { }, "ok/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1010,6 +1077,8 @@ func TestDNS01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) @@ -1026,9 +1095,10 @@ func TestDNS01Validate(t *testing.T) { }, "fail/update-challenge-error": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1051,6 +1121,7 @@ func TestDNS01Validate(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) assert.Equals(t, updch.Status, StatusValid) assert.Equals(t, updch.Error, nil) @@ -1069,9 +1140,10 @@ func TestDNS01Validate(t *testing.T) { }, "ok": func(t *testing.T) test { ch := &Challenge{ - ID: "chID", - Token: "token", - Value: fulldomain, + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1094,6 +1166,7 @@ func TestDNS01Validate(t *testing.T) { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) assert.Equals(t, updch.Status, StatusValid) assert.Equals(t, updch.Error, nil) @@ -1349,7 +1422,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1379,7 +1452,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1415,7 +1488,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1452,7 +1525,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1496,7 +1569,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1539,7 +1612,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1583,7 +1656,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1626,7 +1699,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1692,7 +1765,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1731,7 +1804,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1775,7 +1848,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1818,7 +1891,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1858,7 +1931,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1897,7 +1970,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1942,7 +2015,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -1988,7 +2061,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -2034,7 +2107,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -2078,7 +2151,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusInvalid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) @@ -2123,7 +2196,7 @@ func TestTLSALPN01Validate(t *testing.T) { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equals(t, updch.ID, ch.ID) assert.Equals(t, updch.Token, ch.Token) - assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Status, StatusValid) assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) assert.Equals(t, updch.Error, nil) diff --git a/acme/common.go b/acme/common.go index 05c909eb..26552c61 100644 --- a/acme/common.go +++ b/acme/common.go @@ -19,7 +19,7 @@ type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { - return time.Now().UTC().Round(time.Second) + return time.Now().UTC().Truncate(time.Second) } var clock Clock diff --git a/acme/order.go b/acme/order.go index 7405906d..a003fe9a 100644 --- a/acme/order.go +++ b/acme/order.go @@ -24,10 +24,10 @@ type Order struct { AccountID string `json:"-"` ProvisionerID string `json:"-"` Status Status `json:"status"` - ExpiresAt time.Time `json:"expires,omitempty"` + ExpiresAt time.Time `json:"expires"` Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore,omitempty"` - NotAfter time.Time `json:"notAfter,omitempty"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` Error *Error `json:"error,omitempty"` AuthorizationIDs []string `json:"-"` AuthorizationURLs []string `json:"authorizations"` From 9aef84b9af821b8b060102aa5c42eba4cb1859ae Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 29 Mar 2021 23:02:41 -0700 Subject: [PATCH 43/47] remove unused nonce.clone method --- acme/db/nosql/nonce.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index 6344f0c6..9badae87 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -18,11 +18,6 @@ type dbNonce struct { DeletedAt time.Time } -func (dbn *dbNonce) clone() *dbNonce { - u := *dbn - return &u -} - // CreateNonce creates, stores, and returns an ACME replay-nonce. // Implements the acme.DB interface. func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { From 2e0e62bc4c82c64985e88dfa1b465f6c93ec5b2d Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 29 Mar 2021 23:16:39 -0700 Subject: [PATCH 44/47] add WriteError method for acme api --- acme/errors.go | 29 +++++++++++++++++++++++++++++ api/errors.go | 14 ++++++-------- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/acme/errors.go b/acme/errors.go index f4aa17e7..6ecf0912 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -3,8 +3,13 @@ package acme import ( "encoding/json" "fmt" + "log" + "net/http" + "os" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/logging" ) // ProblemType is the type of the ACME problem. @@ -347,3 +352,27 @@ func (e *Error) ToLog() (interface{}, error) { } return string(b), nil } + +// WriteError writes to w a JSON representation of the given error. +func WriteError(w http.ResponseWriter, err *Error) { + w.Header().Set("Content-Type", "application/problem+json") + w.WriteHeader(err.StatusCode()) + + // Write errors in the response writer + if rl, ok := w.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "error": err.Err, + }) + if os.Getenv("STEPDEBUG") == "1" { + if e, ok := err.Err.(errs.StackTracer); ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e), + }) + } + } + } + + if err := json.NewEncoder(w).Encode(err); err != nil { + log.Println(err) + } +} diff --git a/api/errors.go b/api/errors.go index 67c9ba87..fa2d6a06 100644 --- a/api/errors.go +++ b/api/errors.go @@ -14,12 +14,14 @@ import ( // WriteError writes to w a JSON representation of the given error. func WriteError(w http.ResponseWriter, err error) { - switch err.(type) { + switch k := err.(type) { case *acme.Error: - w.Header().Set("Content-Type", "application/problem+json") + acme.WriteError(w, k) + return default: w.Header().Set("Content-Type", "application/json") } + cause := errors.Cause(err) if sc, ok := err.(errs.StatusCoder); ok { w.WriteHeader(sc.StatusCode()) @@ -33,15 +35,11 @@ func WriteError(w http.ResponseWriter, err error) { // Write errors in the response writer if rl, ok := w.(logging.ResponseLogger); ok { - logErr := err - if u, ok := err.(*acme.Error); ok { - logErr = u.Err - } rl.WithFields(map[string]interface{}{ - "error": logErr, + "error": err, }) if os.Getenv("STEPDEBUG") == "1" { - if e, ok := logErr.(errs.StackTracer); ok { + if e, ok := err.(errs.StackTracer); ok { rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e), }) From 672e3f976e120394c251824160351a13996f65d8 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 12 Apr 2021 19:06:07 -0700 Subject: [PATCH 45/47] Few ACME fixes ... - always URL escape linker output - validateJWS should accept RSAPSS - GetUpdateAccount -> GetOrUpdateAccount --- acme/api/account.go | 4 +-- acme/api/account_test.go | 27 +++++++-------- acme/api/handler.go | 13 ++++---- acme/api/linker.go | 27 ++++++++------- acme/api/linker_test.go | 65 +++++++++++++++++++------------------ acme/api/middleware.go | 6 ++-- acme/api/middleware_test.go | 10 +++--- acme/api/order_test.go | 60 +++++++++++++++++----------------- ca/acmeClient_test.go | 2 -- 9 files changed, 107 insertions(+), 107 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index ae39d2f7..92c5dbfc 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -126,8 +126,8 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { api.JSONStatus(w, acc, httpStatus) } -// GetUpdateAccount is the api for updating an ACME account. -func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { +// GetOrUpdateAccount is the api for updating an ACME account. +func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 7cbe7b7c..c4d7a812 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -32,7 +32,7 @@ func newProv() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ Type: "ACME", - Name: "test@acme-provisioner.com", + Name: "test@acme-provisioner.com", } if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { fmt.Printf("%v", err) @@ -168,11 +168,6 @@ func TestUpdateAccountRequest_Validate(t *testing.T) { } func TestHandler_GetOrdersByAccountID(t *testing.T) { - oids := []string{"foo", "bar"} - oidURLs := []string{ - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/foo", - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/bar", - } accID := "account-id" // Request with chi context @@ -185,6 +180,12 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) + oids := []string{"foo", "bar"} + oidURLs := []string{ + fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName), + fmt.Sprintf("%s/acme/%s/order/bar", baseURL.String(), provName), + } + type test struct { db acme.DB ctx context.Context @@ -287,7 +288,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { func TestHandler_NewAccount(t *testing.T) { prov := newProv() - provName := url.PathEscape(prov.GetName()) + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { @@ -424,7 +425,7 @@ func TestHandler_NewAccount(t *testing.T) { Key: jwk, Status: acme.StatusValid, Contact: []string{"foo", "bar"}, - OrdersURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders", + OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), }, ctx: ctx, statusCode: 201, @@ -486,14 +487,14 @@ func TestHandler_NewAccount(t *testing.T) { assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), - provName, "accountID")}) + escProvName, "accountID")}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } -func TestHandler_GetUpdateAccount(t *testing.T) { +func TestHandler_GetOrUpdateAccount(t *testing.T) { accID := "accountID" acc := acme.Account{ ID: accID, @@ -501,7 +502,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) { OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), } prov := newProv() - provName := url.PathEscape(prov.GetName()) + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { @@ -662,7 +663,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) { req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetUpdateAccount(w, req) + h.GetOrUpdateAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -686,7 +687,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) { assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), - provName, accID)}) + escProvName, accID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/api/handler.go b/acme/api/handler.go index e557f33b..7d02861e 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -102,7 +102,7 @@ func (h *Handler) Route(r api.Router) { } r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) + r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) @@ -125,12 +125,11 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { // Directory represents an ACME directory for configuring clients. type Directory struct { - NewNonce string `json:"newNonce,omitempty"` - NewAccount string `json:"newAccount,omitempty"` - NewOrder string `json:"newOrder,omitempty"` - NewAuthz string `json:"newAuthz,omitempty"` - RevokeCert string `json:"revokeCert,omitempty"` - KeyChange string `json:"keyChange,omitempty"` + NewNonce string `json:"newNonce"` + NewAccount string `json:"newAccount"` + NewOrder string `json:"newOrder"` + RevokeCert string `json:"revokeCert"` + KeyChange string `json:"keyChange"` } // ToLog enables response logging for the Directory type. diff --git a/acme/api/linker.go b/acme/api/linker.go index b6a44dfa..702f7433 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -44,27 +44,26 @@ func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ... // URL dynamically obtained from the request for which the link is being // calculated. func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { - var link string + var u = url.URL{} + // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 + if baseURL != nil { + u = *baseURL + } + switch typ { case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: - link = fmt.Sprintf("/%s/%s", provisionerName, typ) + u.Path = fmt.Sprintf("/%s/%s", provisionerName, typ) case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: - link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + u.Path = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) case ChallengeLinkType: - link = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + u.Path = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) case OrdersByAccountLinkType: - link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + u.Path = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) case FinalizeLinkType: - link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + u.Path = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) } if abs { - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - u := url.URL{} - if baseURL != nil { - u = *baseURL - } - // If no Scheme is set, then default to https. if u.Scheme == "" { u.Scheme = "https" @@ -75,10 +74,10 @@ func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, u.Host = l.dns } - u.Path = l.prefix + link + u.Path = l.prefix + u.Path return u.String() } - return link + return u.EscapedPath() } // LinkType captures the link type. diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go index 2252e334..6bb1f739 100644 --- a/acme/api/linker_test.go +++ b/acme/api/linker_test.go @@ -51,52 +51,53 @@ func TestLinker_GetLinkExplicit(t *testing.T) { id := "1234" prov := newProv() - provID := url.PathEscape(prov.GetName()) + provName := prov.GetName() + escProvName := url.PathEscape(provName) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-nonce", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-account", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-order", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-authz", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, false, baseURL), fmt.Sprintf("/%s/directory", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, false, baseURL), fmt.Sprintf("/%s/revoke-cert", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, false, baseURL), fmt.Sprintf("/%s/key-change", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) - assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", escProvName, id, id)) - assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) - assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", escProvName)) } func TestLinker_LinkOrder(t *testing.T) { diff --git a/acme/api/middleware.go b/acme/api/middleware.go index e06e4736..861876a9 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -90,7 +90,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ct := r.Header.Get("Content-Type") var expected []string - if strings.Contains(r.URL.Path, h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { + if strings.Contains(r.URL.String(), h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { @@ -170,7 +170,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { } hdr := sig.Protected switch hdr.Algorithm { - case jose.RS256, jose.RS384, jose.RS512: + case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512: if hdr.JSONWebKey != nil { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: @@ -189,7 +189,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", hdr.Algorithm)) + api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 1c0f3689..4c316910 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -228,9 +228,9 @@ func TestHandler_addDirLink(t *testing.T) { func TestHandler_verifyContentType(t *testing.T) { prov := newProv() - provName := prov.GetName() + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), provName) + url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { h Handler ctx context.Context @@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) { h: Handler{ linker: NewLinker("dns", "acme"), }, - url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + url: url, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, @@ -1160,7 +1160,7 @@ func TestHandler_validateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: none"), + err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), } }, "fail/unsuitable-algorithm-mac": func(t *testing.T) test { @@ -1172,7 +1172,7 @@ func TestHandler_validateJWS(t *testing.T) { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", jose.HS256), + err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), } }, "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 506f0a0a..300aa61b 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -149,6 +149,10 @@ func TestFinalizeRequestValidate(t *testing.T) { } func TestHandler_GetOrder(t *testing.T) { + prov := newProv() + escProvName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + now := clock.Now() nbf := now naf := now.Add(24 * time.Hour) @@ -171,21 +175,18 @@ func TestHandler_GetOrder(t *testing.T) { Status: acme.StatusInvalid, Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), AuthorizationURLs: []string{ - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", + fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName), }, - FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", + FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName), } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} url := fmt.Sprintf("%s/acme/%s/order/%s", - baseURL.String(), provName, o.ID) + baseURL.String(), escProvName, o.ID) type test struct { db acme.DB @@ -285,7 +286,7 @@ func TestHandler_GetOrder(t *testing.T) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{ AccountID: "accountID", - ProvisionerID: "acme/test@acme-provisioner.com", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: clock.Now().Add(-time.Hour), Status: acme.StatusReady, }, nil @@ -311,7 +312,7 @@ func TestHandler_GetOrder(t *testing.T) { return &acme.Order{ ID: "orderID", AccountID: "accountID", - ProvisionerID: "acme/test@acme-provisioner.com", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: expiry, Status: acme.StatusReady, AuthorizationIDs: []string{"foo", "bar", "baz"}, @@ -581,10 +582,10 @@ func TestHandler_newAuthorization(t *testing.T) { func TestHandler_NewOrder(t *testing.T) { // Request with chi context prov := newProv() - provName := url.PathEscape(prov.GetName()) + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} url := fmt.Sprintf("%s/acme/%s/order/ordID", - baseURL.String(), provName) + baseURL.String(), escProvName) type test struct { db acme.DB @@ -877,8 +878,8 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{ - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID", - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az2ID", + fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), escProvName), }) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) @@ -968,7 +969,7 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) @@ -1059,7 +1060,7 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) @@ -1149,7 +1150,7 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) @@ -1240,7 +1241,7 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) @@ -1291,6 +1292,10 @@ func TestHandler_NewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { + prov := newProv() + escProvName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + now := clock.Now() nbf := now naf := now.Add(24 * time.Hour) @@ -1311,22 +1316,19 @@ func TestHandler_FinalizeOrder(t *testing.T) { ExpiresAt: naf, Status: acme.StatusValid, AuthorizationURLs: []string{ - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", - "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", + fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName), }, - FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", - CertificateURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/certificate/certID", + FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName), + CertificateURL: fmt.Sprintf("%s/acme/%s/certificate/certID", baseURL.String(), escProvName), } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} url := fmt.Sprintf("%s/acme/%s/order/%s", - baseURL.String(), provName, o.ID) + baseURL.String(), escProvName, o.ID) _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") assert.FatalError(t, err) @@ -1488,7 +1490,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{ AccountID: "accountID", - ProvisionerID: "acme/test@acme-provisioner.com", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: clock.Now().Add(-time.Hour), Status: acme.StatusReady, }, nil @@ -1515,7 +1517,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { return &acme.Order{ ID: "orderID", AccountID: "accountID", - ProvisionerID: "acme/test@acme-provisioner.com", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: naf, Status: acme.StatusValid, AuthorizationIDs: []string{"foo", "bar", "baz"}, diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index b97fdbd0..f5963de4 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -35,7 +35,6 @@ func TestNewACMEClient(t *testing.T) { NewNonce: srv.URL + "/foo", NewAccount: srv.URL + "/bar", NewOrder: srv.URL + "/baz", - NewAuthz: srv.URL + "/zap", RevokeCert: srv.URL + "/zip", KeyChange: srv.URL + "/blorp", } @@ -146,7 +145,6 @@ func TestACMEClient_GetDirectory(t *testing.T) { NewNonce: "/foo", NewAccount: "/bar", NewOrder: "/baz", - NewAuthz: "/zap", RevokeCert: "/zip", KeyChange: "/blorp", }, From 63ec2e35b00c69863c8c40fdea63b5ae214103dc Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 13 Apr 2021 14:42:13 -0700 Subject: [PATCH 46/47] Change Clock to empty struct in nosql/nosql | truncate > round - saves space - --- acme/db/nosql/nosql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index b8f79edc..052f5729 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -86,11 +86,11 @@ func randID() (val string, err error) { } // Clock that returns time in UTC rounded to seconds. -type Clock int +type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { - return time.Now().UTC().Round(time.Second) + return time.Now().UTC().Truncate(time.Second) } var clock = new(Clock) From 6cfb9b790cbfab624eb4d79c85ea846f25792af3 Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 13 Apr 2021 14:53:05 -0700 Subject: [PATCH 47/47] Remove check of deprecated value - NegotiatedProtocolIsMutual is always true: Deprecated according to golang docs --- acme/challenge.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/challenge.go b/acme/challenge.go index a47fc7df..1059e437 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -134,7 +134,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "%s challenge for %s resulted in no certificates", ch.Type, ch.Value)) } - if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { + if cs.NegotiatedProtocol != "acme-tls/1" { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) }