From f34fb80eb61a4a84bc98fd32f45c450cf14d54f9 Mon Sep 17 00:00:00 2001 From: max furman Date: Tue, 20 Oct 2020 16:18:16 -0700 Subject: [PATCH] [acme] Use lock for ordersByAccID and type to house methods --- acme/account.go | 46 -------------------------- acme/authority.go | 3 +- acme/order.go | 84 +++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 80 insertions(+), 53 deletions(-) diff --git a/acme/account.go b/acme/account.go index ea0e7fdc..1c5870d5 100644 --- a/acme/account.go +++ b/acme/account.go @@ -195,49 +195,3 @@ func getAccountByKeyID(db nosql.DB, kid string) (*account, error) { } return getAccountByID(db, string(id)) } - -// getOrderIDsByAccount retrieves a list of Order IDs that were created by the -// account. -func getOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { - b, err := db.Get(ordersByAccountIDTable, []byte(accID)) - if err != nil { - if nosql.IsErrNotFound(err) { - return []string{}, nil - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) - } - var oids []string - if err := json.Unmarshal(b, &oids); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) - } - - // Remove any order that is not in PENDING state and update the stored list - // before returning. - // - // According to RFC 8555: - // The server SHOULD include pending orders and SHOULD NOT include orders - // that are invalid in the array of URLs. - pendOids := []string{} - for _, oid := range oids { - o, err := getOrder(db, oid) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) - } - if o, err = o.updateStatus(db); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) - } - if o.Status == StatusPending { - pendOids = append(pendOids, oid) - } - } - // If the number of pending orders is less than the number of orders in the - // list, then update the pending order list. - if len(pendOids) != len(oids) { - if err = orderIDs(pendOids).save(db, oids, accID); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids))) - } - } - - return pendOids, nil -} diff --git a/acme/authority.go b/acme/authority.go index 959dc9c4..d1bb0aaf 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -233,7 +233,8 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order // GetOrdersByAccount returns the list of order urls owned by the account. func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - oids, err := getOrderIDsByAccount(a.db, id) + var oiba = orderIDsByAccount{} + oids, err := oiba.getOrderIDsByAccount(a.db, id, false) if err != nil { return nil, err } diff --git a/acme/order.go b/acme/order.go index 57168419..ef5345e4 100644 --- a/acme/order.go +++ b/acme/order.go @@ -6,6 +6,7 @@ import ( "encoding/json" "sort" "strings" + "sync" "time" "github.com/pkg/errors" @@ -16,6 +17,9 @@ import ( var defaultOrderExpiry = time.Hour * 24 +// Mutex for locking ordersByAccount index operations. +var ordersByAccountMux = &sync.Mutex{} + // Order contains order metadata for the ACME protocol order type. type Order struct { Status string `json:"status"` @@ -111,17 +115,84 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { return nil, err } - // Update the "order IDs by account ID" index // - oids, err := getOrderIDsByAccount(db, ops.AccountID) + var oidHelper = orderIDsByAccount{} + _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID) if err != nil { return nil, err } - newOids := append(oids, o.ID) - if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil { - db.Del(orderTable, []byte(o.ID)) + return o, nil +} + +type orderIDsByAccount struct{} + +func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { + ordersByAccountMux.Lock() + defer ordersByAccountMux.Unlock() + + // Update the "order IDs by account ID" index + oids, err := oiba.getOrderIDsByAccount(db, accID, true) + if err != nil { return nil, err } - return o, nil + newOids := append(oids, oid) + if err = orderIDs(newOids).save(db, oids, accID); err != nil { + // Delete the entire order if storing the index fails. + db.Del(orderTable, []byte(oid)) + return nil, err + } + return newOids, nil +} + +// getOrderIDsByAccount retrieves a list of Order IDs that were created by the +// account. +func (oiba orderIDsByAccount) getOrderIDsByAccount(db nosql.DB, accID string, alreadyLocked bool) ([]string, error) { + if !alreadyLocked { + ordersByAccountMux.Lock() + + defer ordersByAccountMux.Unlock() + } + + b, err := db.Get(ordersByAccountIDTable, []byte(accID)) + if err != nil { + if nosql.IsErrNotFound(err) { + return []string{}, nil + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) + } + var oids []string + if err := json.Unmarshal(b, &oids); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) + } + + // Remove any order that is not in PENDING state and update the stored list + // before returning. + // + // According to RFC 8555: + // The server SHOULD include pending orders and SHOULD NOT include orders + // that are invalid in the array of URLs. + pendOids := []string{} + for _, oid := range oids { + o, err := getOrder(db, oid) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) + } + if o, err = o.updateStatus(db); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) + } + if o.Status == StatusPending { + pendOids = append(pendOids, oid) + } + } + // If the number of pending orders is less than the number of orders in the + // list, then update the pending order list. + if len(pendOids) != len(oids) { + if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ + "len(orderIDs) = %d", len(pendOids))) + } + } + + return pendOids, nil } type orderIDs []string @@ -271,6 +342,7 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut if o, err = o.updateStatus(db); err != nil { return nil, err } + switch o.Status { case StatusInvalid: return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))