From e3826dd1c33f53e653a243547d2d1a0d3f5a104e Mon Sep 17 00:00:00 2001 From: max furman Date: Sun, 26 May 2019 17:41:10 -0700 Subject: [PATCH] Add ACME CA capabilities --- .golangci.yml | 1 + .travis.yml | 2 +- Gopkg.lock | 2 +- acme/account.go | 214 +++ acme/account_test.go | 844 +++++++++++ acme/api/account.go | 213 +++ acme/api/account_test.go | 790 +++++++++++ acme/api/handler.go | 214 +++ acme/api/handler_test.go | 771 ++++++++++ acme/api/middleware.go | 377 +++++ acme/api/middleware_test.go | 1550 +++++++++++++++++++++ acme/api/order.go | 164 +++ acme/api/order_test.go | 757 ++++++++++ acme/authority.go | 263 ++++ acme/authority_test.go | 1474 ++++++++++++++++++++ acme/authz.go | 344 +++++ acme/authz_test.go | 809 +++++++++++ acme/certificate.go | 89 ++ acme/certificate_test.go | 253 ++++ acme/challenge.go | 445 ++++++ acme/challenge_test.go | 1093 +++++++++++++++ acme/common.go | 76 + acme/directory.go | 120 ++ acme/directory_test.go | 63 + acme/errors.go | 439 ++++++ acme/nonce.go | 73 + acme/nonce_test.go | 163 +++ acme/order.go | 342 +++++ acme/order_test.go | 1129 +++++++++++++++ api/api.go | 24 +- api/api_test.go | 8 + api/errors.go | 9 +- api/revoke.go | 2 - api/utils.go | 33 + authority/authority.go | 4 +- authority/error.go | 24 + authority/provisioner/acme.go | 85 ++ authority/provisioner/acme_test.go | 184 +++ authority/provisioner/collection.go | 5 +- authority/provisioner/collection_test.go | 9 + authority/provisioner/provisioner.go | 96 ++ authority/provisioner/sign_ssh_options.go | 6 +- authority/provisioner/utils_test.go | 12 + authority/provisioners.go | 10 + authority/testdata/certs/badsig.csr | 8 + authority/testdata/certs/foo.csr | 8 + ca/acmeClient.go | 354 +++++ ca/acmeClient_test.go | 1358 ++++++++++++++++++ ca/ca.go | 52 +- ca/client_test.go | 30 +- db/db.go | 129 +- db/db_test.go | 132 +- db/simple.go | 55 + docs/acme.md | 160 +++ 54 files changed, 15687 insertions(+), 184 deletions(-) create mode 100644 acme/account.go create mode 100644 acme/account_test.go create mode 100644 acme/api/account.go create mode 100644 acme/api/account_test.go create mode 100644 acme/api/handler.go create mode 100644 acme/api/handler_test.go create mode 100644 acme/api/middleware.go create mode 100644 acme/api/middleware_test.go create mode 100644 acme/api/order.go create mode 100644 acme/api/order_test.go create mode 100644 acme/authority.go create mode 100644 acme/authority_test.go create mode 100644 acme/authz.go create mode 100644 acme/authz_test.go create mode 100644 acme/certificate.go create mode 100644 acme/certificate_test.go create mode 100644 acme/challenge.go create mode 100644 acme/challenge_test.go create mode 100644 acme/common.go create mode 100644 acme/directory.go create mode 100644 acme/directory_test.go create mode 100644 acme/errors.go create mode 100644 acme/nonce.go create mode 100644 acme/nonce_test.go create mode 100644 acme/order.go create mode 100644 acme/order_test.go create mode 100644 authority/provisioner/acme.go create mode 100644 authority/provisioner/acme_test.go create mode 100644 authority/testdata/certs/badsig.csr create mode 100644 authority/testdata/certs/foo.csr create mode 100644 ca/acmeClient.go create mode 100644 ca/acmeClient_test.go create mode 100644 docs/acme.md diff --git a/.golangci.yml b/.golangci.yml index 706a6b0b..0ac34cf4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -60,6 +60,7 @@ issues: - declaration of "err" shadows declaration at line - should have a package comment, unless it's in another file for this package - error strings should not be capitalized or end with punctuation or a newline + - declaration of "authz" shadows declaration at line # golangci.com configuration # https://github.com/golangci/golangci/wiki/Configuration service: diff --git a/.travis.yml b/.travis.yml index b602cfc6..991bee27 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: go go: -- 1.12.x +- 1.13.x addons: apt: packages: diff --git a/Gopkg.lock b/Gopkg.lock index 68dc523b..8d6ac69b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -228,7 +228,7 @@ "utils", ] pruneopts = "UT" - revision = "e097873f958542df7505184bee0fadfcf17027de" + revision = "ae6e517f70783467afe6199e12fb43309a7e693e" [[projects]] branch = "master" diff --git a/acme/account.go b/acme/account.go new file mode 100644 index 00000000..3167dd09 --- /dev/null +++ b/acme/account.go @@ -0,0 +1,214 @@ +package acme + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql" +) + +// Account is a subset of the internal account type containing only those +// attributes required for responses in the ACME protocol. +type Account struct { + Contact []string `json:"contact,omitempty"` + Status string `json:"status"` + Orders string `json:"orders"` + ID string `json:"-"` + Key *jose.JSONWebKey `json:"-"` +} + +// ToLog enables response logging. +func (a *Account) ToLog() (interface{}, error) { + b, err := json.Marshal(a) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging")) + } + return string(b), nil +} + +// GetID returns the account ID. +func (a *Account) GetID() string { + return a.ID +} + +// GetKey returns the JWK associated with the account. +func (a *Account) GetKey() *jose.JSONWebKey { + return a.Key +} + +// IsValid returns true if the Account is valid. +func (a *Account) IsValid() bool { + return a.Status == StatusValid +} + +// AccountOptions are the options needed to create a new ACME account. +type AccountOptions struct { + Key *jose.JSONWebKey + Contact []string +} + +// account represents an ACME account. +type account struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Deactivated time.Time `json:"deactivated"` + Key *jose.JSONWebKey `json:"key"` + Contact []string `json:"contact,omitempty"` + Status string `json:"status"` +} + +// newAccount returns a new acme account type. +func newAccount(db nosql.DB, ops AccountOptions) (*account, error) { + id, err := randID() + if err != nil { + return nil, err + } + + a := &account{ + ID: id, + Key: ops.Key, + Contact: ops.Contact, + Status: "valid", + Created: clock.Now(), + } + return a, a.saveNew(db) +} + +// toACME converts the internal Account type into the public acmeAccount +// type for presentation in the ACME protocol. +func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) { + return &Account{ + Status: a.Status, + Contact: a.Contact, + Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID), + Key: a.Key, + ID: a.ID, + }, nil +} + +// save writes the Account to the DB. +// If the account is new then the necessary indices will be created. +// Else, the account in the DB will be updated. +func (a *account) saveNew(db nosql.DB) error { + kid, err := keyToID(a.Key) + if err != nil { + return err + } + kidB := []byte(kid) + + // Set the jwkID -> acme account ID index + _, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID)) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index")) + case !swapped: + return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) + default: + if err = a.save(db, nil); err != nil { + db.Del(accountByKeyIDTable, kidB) + return err + } + return nil + } +} + +func (a *account) save(db nosql.DB, old *account) error { + var ( + err error + oldB []byte + ) + if old == nil { + oldB = nil + } else { + if oldB, err = json.Marshal(old); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) + } + } + + b, err := json.Marshal(*a) + if err != nil { + return errors.Wrap(err, "error marshaling new account object") + } + // Set the Account + _, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error storing account")) + case !swapped: + return ServerInternalErr(errors.New("error storing account; " + + "value has changed since last read")) + default: + return nil + } +} + +// update updates the acme account object stored in the database if, +// and only if, the account has not changed since the last read. +func (a *account) update(db nosql.DB, contact []string) (*account, error) { + b := *a + b.Contact = contact + if err := b.save(db, a); err != nil { + return nil, err + } + return &b, nil +} + +// deactivate deactivates the acme account. +func (a *account) deactivate(db nosql.DB) (*account, error) { + b := *a + b.Status = StatusDeactivated + b.Deactivated = clock.Now() + if err := b.save(db, a); err != nil { + return nil, err + } + return &b, nil +} + +// getAccountByID retrieves the account with the given ID. +func getAccountByID(db nosql.DB, id string) (*account, error) { + ab, err := db.Get(accountTable, []byte(id)) + if err != nil { + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) + } + + a := new(account) + if err = json.Unmarshal(ab, a); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) + } + return a, nil +} + +// getAccountByKeyID retrieves Id associated with the given Kid. +func getAccountByKeyID(db nosql.DB, kid string) (*account, error) { + id, err := db.Get(accountByKeyIDTable, []byte(kid)) + if err != nil { + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) + } + return getAccountByID(db, string(id)) +} + +// getOrderIDsByAccount retrieves a list of Order IDs that were created by the +// account. +func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) { + b, err := db.Get(ordersByAccountIDTable, []byte(id)) + if err != nil { + if nosql.IsErrNotFound(err) { + return []string{}, nil + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id)) + } + var orderIDs []string + if err := json.Unmarshal(b, &orderIDs); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id)) + } + return orderIDs, nil +} diff --git a/acme/account_test.go b/acme/account_test.go new file mode 100644 index 00000000..37af69dc --- /dev/null +++ b/acme/account_test.go @@ -0,0 +1,844 @@ +package acme + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +var ( + defaultDisableRenewal = false + globalProvisionerClaims = provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + } +) + +func newProv() provisioner.Interface { + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { + fmt.Printf("%v", err) + } + return p +} + +func newAcc() (*account, error) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + if err != nil { + return nil, err + } + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + } + return newAccount(mockdb, AccountOptions{ + Key: jwk, Contact: []string{"foo", "bar"}, + }) +} + +func TestGetAccountByID(t *testing.T) { + type test struct { + id string + db nosql.DB + acc *account + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + return test{ + acc: acc, + id: acc.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)), + } + }, + "fail/db-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + return test{ + acc: acc, + id: acc.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + return test{ + acc: acc, + id: acc.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + return nil, nil + }, + }, + err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + return test{ + acc: acc, + id: acc.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acc, err := getAccountByID(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, acc.ID) + assert.Equals(t, tc.acc.Status, acc.Status) + assert.Equals(t, tc.acc.Created, acc.Created) + assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) + assert.Equals(t, tc.acc.Contact, acc.Contact) + assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) + } + } + }) + } +} + +func TestGetAccountByKeyID(t *testing.T) { + type test struct { + kid string + db nosql.DB + acc *account + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/kid-not-found": func(t *testing.T) test { + return test{ + kid: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("account with key id foo not found: not found")), + } + }, + "fail/db-error": func(t *testing.T) test { + return test{ + kid: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading key-account index: force")), + } + }, + "fail/getAccount-error": func(t *testing.T) test { + count := 0 + return test{ + kid: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if count == 0 { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte("foo")) + count++ + return []byte("bar"), nil + } + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading account bar: force")), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + count := 0 + return test{ + kid: acc.Key.KeyID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(acc.Key.KeyID)) + ret = []byte(acc.ID) + case 1: + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + ret = b + } + count++ + return ret, nil + }, + }, + acc: acc, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, acc.ID) + assert.Equals(t, tc.acc.Status, acc.Status) + assert.Equals(t, tc.acc.Created, acc.Created) + assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) + assert.Equals(t, tc.acc.Contact, acc.Contact) + assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) + } + } + }) + } +} + +func TestGetAccountIDsByAccount(t *testing.T) { + type test struct { + id string + db nosql.DB + res []string + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/not-found": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + res: []string{}, + } + }, + "fail/db-error": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, nil + }, + }, + err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), + } + }, + "ok": func(t *testing.T) test { + oids := []string{"foo", "bar", "baz"} + b, err := json.Marshal(oids) + assert.FatalError(t, err) + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return b, nil + }, + }, + res: oids, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res, oids) + } + } + }) + } +} + +func TestAccountToACME(t *testing.T) { + dir := newDirectory("ca.smallstep.com", "acme") + prov := newProv() + + type test struct { + acc *account + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + return test{acc: acc} + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + acmeAccount, err := tc.acc.toACME(nil, dir, prov) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acmeAccount.ID, tc.acc.ID) + assert.Equals(t, acmeAccount.Status, tc.acc.Status) + assert.Equals(t, acmeAccount.Contact, tc.acc.Contact) + assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID) + assert.Equals(t, acmeAccount.Orders, + fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID)) + } + } + }) + } +} + +func TestAccountSave(t *testing.T) { + type test struct { + acc, old *account + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + return test{ + acc: acc, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing account: force")), + } + }, + "fail/old-nil/swap-false": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + return test{ + acc: acc, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil + }, + }, + err: ServerInternalErr(errors.New("error storing account; value has changed since last read")), + } + }, + "ok/old-nil": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + return test{ + acc: acc, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, accountTable) + assert.Equals(t, []byte(acc.ID), key) + return nil, true, nil + }, + }, + } + }, + "ok/old-not-nil": func(t *testing.T) test { + oldAcc, err := newAcc() + assert.FatalError(t, err) + acc, err := newAcc() + assert.FatalError(t, err) + + oldb, err := json.Marshal(oldAcc) + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + return test{ + acc: acc, + old: oldAcc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, oldb) + assert.Equals(t, newval, b) + assert.Equals(t, bucket, accountTable) + assert.Equals(t, []byte(acc.ID), key) + return []byte("foo"), true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.acc.save(tc.db, tc.old); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestAccountSaveNew(t *testing.T) { + type test struct { + acc *account + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/keyToID-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + acc.Key.Key = "foo" + return test{ + acc: acc, + err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), + } + }, + "fail/swap-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + kid, err := keyToID(acc.Key) + assert.FatalError(t, err) + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + assert.Equals(t, old, nil) + assert.Equals(t, newval, []byte(acc.ID)) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), + } + }, + "fail/swap-false": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + kid, err := keyToID(acc.Key) + assert.FatalError(t, err) + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + assert.Equals(t, old, nil) + assert.Equals(t, newval, []byte(acc.ID)) + return nil, false, nil + }, + }, + err: ServerInternalErr(errors.New("key-id to account-id index already exists")), + } + }, + "fail/save-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + kid, err := keyToID(acc.Key) + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + count := 0 + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 0 { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + assert.Equals(t, old, nil) + assert.Equals(t, newval, []byte(acc.ID)) + count++ + return nil, true, nil + } + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + assert.Equals(t, old, nil) + assert.Equals(t, newval, b) + return nil, false, errors.New("force") + }, + MDel: func(bucket, key []byte) error { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + return nil + }, + }, + err: ServerInternalErr(errors.New("error storing account: force")), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + kid, err := keyToID(acc.Key) + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + count := 0 + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 0 { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + assert.Equals(t, old, nil) + assert.Equals(t, newval, []byte(acc.ID)) + count++ + return nil, true, nil + } + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + assert.Equals(t, old, nil) + assert.Equals(t, newval, b) + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.acc.saveNew(tc.db); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestAccountUpdate(t *testing.T) { + type test struct { + acc *account + contact []string + db nosql.DB + res []byte + err *Error + } + contact := []string{"foo", "bar"} + tests := map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + oldb, err := json.Marshal(acc) + assert.FatalError(t, err) + + _acc := *acc + clone := &_acc + clone.Contact = contact + b, err := json.Marshal(clone) + assert.FatalError(t, err) + return test{ + acc: acc, + contact: contact, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, b) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing account: force")), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + oldb, err := json.Marshal(acc) + assert.FatalError(t, err) + + _acc := *acc + clone := &_acc + clone.Contact = contact + b, err := json.Marshal(clone) + assert.FatalError(t, err) + return test{ + acc: acc, + contact: contact, + res: b, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, b) + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + acc, err := tc.acc.update(tc.db, tc.contact) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + b, err := json.Marshal(acc) + assert.FatalError(t, err) + assert.Equals(t, b, tc.res) + } + } + }) + } +} + +func TestAccountDeactivate(t *testing.T) { + type test struct { + acc *account + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + oldb, err := json.Marshal(acc) + assert.FatalError(t, err) + + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + assert.Equals(t, old, oldb) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing account: force")), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + oldb, err := json.Marshal(acc) + assert.FatalError(t, err) + + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + assert.Equals(t, old, oldb) + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + acc, err := tc.acc.deactivate(tc.db) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.acc.ID) + assert.Equals(t, acc.Contact, tc.acc.Contact) + assert.Equals(t, acc.Status, StatusDeactivated) + assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID) + assert.Equals(t, acc.Created, tc.acc.Created) + + assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute))) + assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute))) + } + } + }) + } +} + +func TestNewAccount(t *testing.T) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + kid, err := keyToID(jwk) + assert.FatalError(t, err) + ops := AccountOptions{ + Key: jwk, + Contact: []string{"foo", "bar"}, + } + type test struct { + ops AccountOptions + db nosql.DB + err *Error + id *string + } + tests := map[string]func(t *testing.T) test{ + "fail/store-error": func(t *testing.T) test { + return test{ + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), + } + }, + "ok": func(t *testing.T) test { + var _id string + id := &_id + count := 0 + return test{ + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + switch count { + case 0: + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + case 1: + assert.Equals(t, bucket, accountTable) + *id = string(key) + } + count++ + return nil, true, nil + }, + }, + id: id, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + acc, err := newAccount(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, *tc.id) + assert.Equals(t, acc.Status, StatusValid) + assert.Equals(t, acc.Contact, ops.Contact) + assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID) + + assert.True(t, acc.Deactivated.IsZero()) + + assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute))) + } + } + }) + } +} diff --git a/acme/api/account.go b/acme/api/account.go new file mode 100644 index 00000000..05d6a084 --- /dev/null +++ b/acme/api/account.go @@ -0,0 +1,213 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/logging" +) + +// NewAccountRequest represents the payload for a new account request. +type NewAccountRequest struct { + Contact []string `json:"contact"` + OnlyReturnExisting bool `json:"onlyReturnExisting"` + TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"` +} + +func validateContacts(cs []string) error { + for _, c := range cs { + if len(c) == 0 { + return acme.MalformedErr(errors.New("contact cannot be empty string")) + } + } + return nil +} + +// Validate validates a new-account request body. +func (n *NewAccountRequest) Validate() error { + if n.OnlyReturnExisting && len(n.Contact) > 0 { + return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone")) + } + return validateContacts(n.Contact) +} + +// UpdateAccountRequest represents an update-account request. +type UpdateAccountRequest struct { + Contact []string `json:"contact"` + Status string `json:"status"` +} + +// IsDeactivateRequest returns true if the update request is a deactivation +// request, false otherwise. +func (u *UpdateAccountRequest) IsDeactivateRequest() bool { + return u.Status == acme.StatusDeactivated +} + +// Validate validates a update-account request body. +func (u *UpdateAccountRequest) Validate() error { + switch { + case len(u.Status) > 0 && len(u.Contact) > 0: + return acme.MalformedErr(errors.New("incompatible input; contact and " + + "status updates are mutually exclusive")) + case len(u.Contact) > 0: + if err := validateContacts(u.Contact); err != nil { + return err + } + return nil + case len(u.Status) > 0: + if u.Status != acme.StatusDeactivated { + return acme.MalformedErr(errors.Errorf("cannot update account "+ + "status to %s, only deactivated", u.Status)) + } + return nil + default: + return acme.MalformedErr(errors.Errorf("empty update request")) + } +} + +// NewAccount is the handler resource for creating new ACME accounts. +func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + payload, err := payloadFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + var nar NewAccountRequest + if err := json.Unmarshal(payload.value, &nar); err != nil { + api.WriteError(w, acme.MalformedErr(errors.Wrap(err, + "failed to unmarshal new-account request payload"))) + return + } + if err := nar.Validate(); err != nil { + api.WriteError(w, err) + return + } + + httpStatus := http.StatusCreated + acc, err := accountFromContext(r) + if err != nil { + acmeErr, ok := err.(*acme.Error) + if !ok || acmeErr.Status != http.StatusNotFound { + // Something went wrong ... + api.WriteError(w, err) + return + } + + // Account does not exist // + if nar.OnlyReturnExisting { + api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + return + } + jwk, err := jwkFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + + if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{ + Key: jwk, + Contact: nar.Contact, + }); err != nil { + api.WriteError(w, err) + return + } + } else { + // Account exists // + httpStatus = http.StatusOK + } + + w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, + acme.URLSafeProvisionerName(prov), true, acc.GetID())) + api.JSONStatus(w, acc, httpStatus) + return +} + +// GetUpdateAccount is the api for updating an ACME account. +func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + payload, err := payloadFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + + if !payload.isPostAsGet { + var uar UpdateAccountRequest + if err := json.Unmarshal(payload.value, &uar); err != nil { + api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload"))) + return + } + if err := uar.Validate(); err != nil { + api.WriteError(w, err) + return + } + var err error + if uar.IsDeactivateRequest() { + acc, err = h.Auth.DeactivateAccount(prov, acc.GetID()) + } else { + acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact) + } + if err != nil { + api.WriteError(w, err) + return + } + } + w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID())) + api.JSON(w, acc) + return +} + +func logOrdersByAccount(w http.ResponseWriter, oids []string) { + if rl, ok := w.(logging.ResponseLogger); ok { + m := map[string]interface{}{ + "orders": oids, + } + rl.WithFields(m) + } +} + +// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. +func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + + accID := chi.URLParam(r, "accID") + if acc.ID != accID { + api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) + return + } + orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + api.JSON(w, orders) + logOrdersByAccount(w, orders) + return +} diff --git a/acme/api/account_test.go b/acme/api/account_test.go new file mode 100644 index 00000000..193088f2 --- /dev/null +++ b/acme/api/account_test.go @@ -0,0 +1,790 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" +) + +var ( + defaultDisableRenewal = false + globalProvisionerClaims = provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + } +) + +func newProv() provisioner.Interface { + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { + fmt.Printf("%v", err) + } + return p +} + +func TestNewAccountRequestValidate(t *testing.T) { + type test struct { + nar *NewAccountRequest + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/incompatible-input": func(t *testing.T) test { + return test{ + nar: &NewAccountRequest{ + OnlyReturnExisting: true, + Contact: []string{"foo", "bar"}, + }, + err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")), + } + }, + "fail/bad-contact": func(t *testing.T) test { + return test{ + nar: &NewAccountRequest{ + Contact: []string{"foo", ""}, + }, + err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + } + }, + "ok": func(t *testing.T) test { + return test{ + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + }, + } + }, + "ok/onlyReturnExisting": func(t *testing.T) test { + return test{ + nar: &NewAccountRequest{ + OnlyReturnExisting: true, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + if err := tc.nar.Validate(); err != nil { + if assert.NotNil(t, err) { + ae, ok := err.(*acme.Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestUpdateAccountRequestValidate(t *testing.T) { + type test struct { + uar *UpdateAccountRequest + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/incompatible-input": func(t *testing.T) test { + return test{ + uar: &UpdateAccountRequest{ + Contact: []string{"foo", "bar"}, + Status: "foo", + }, + err: acme.MalformedErr(errors.Errorf("incompatible input; " + + "contact and status updates are mutually exclusive")), + } + }, + "fail/bad-contact": func(t *testing.T) test { + return test{ + uar: &UpdateAccountRequest{ + Contact: []string{"foo", ""}, + }, + err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + } + }, + "fail/bad-status": func(t *testing.T) test { + return test{ + uar: &UpdateAccountRequest{ + Status: "foo", + }, + err: acme.MalformedErr(errors.Errorf("cannot update account " + + "status to foo, only deactivated")), + } + }, + "ok/contact": func(t *testing.T) test { + return test{ + uar: &UpdateAccountRequest{ + Contact: []string{"foo", "bar"}, + }, + } + }, + "ok/status": func(t *testing.T) test { + return test{ + uar: &UpdateAccountRequest{ + Status: "deactivated", + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + if err := tc.uar.Validate(); err != nil { + if assert.NotNil(t, err) { + ae, ok := err.(*acme.Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestHandlerGetOrdersByAccount(t *testing.T) { + oids := []string{ + "https://ca.smallstep.com/acme/order/foo", + "https://ca.smallstep.com/acme/order/bar", + } + accID := "account-id" + prov := newProv() + + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("accID", accID) + url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + auth: &mockAcmeAuthority{}, + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "foo"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{}, + ctx: ctx, + statusCode: 401, + problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")), + } + }, + "fail/getOrdersByAccount-error": func(t *testing.T) test { + acc := &acme.Account{ID: accID} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.ServerInternalErr(errors.New("force")), + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: accID} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) { + assert.Equals(t, p, prov) + assert.Equals(t, id, acc.ID) + return oids, nil + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetOrdersByAccount(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(oids) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} + +func TestHandlerNewAccount(t *testing.T) { + accID := "accountID" + acc := acme.Account{ + ID: accID, + Status: "valid", + Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), + } + prov := newProv() + + url := "https://ca.smallstep.com/acme/new-account" + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-payload": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/nil-payload": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/unmarshal-payload-error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + } + }, + "fail/malformed-payload-error": func(t *testing.T) test { + nar := &NewAccountRequest{ + Contact: []string{"foo", ""}, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + } + }, + "fail/no-existing-account": func(t *testing.T) test { + nar := &NewAccountRequest{ + OnlyReturnExisting: true, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/no-jwk": func(t *testing.T) test { + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + } + }, + "fail/nil-jwk": func(t *testing.T) test { + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + } + }, + "fail/NewAccount-error": func(t *testing.T) test { + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + return test{ + auth: &mockAcmeAuthority{ + newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, ops.Contact, nar.Contact) + assert.Equals(t, ops.Key, jwk) + return nil, acme.ServerInternalErr(errors.New("force")) + }, + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok/new-account": func(t *testing.T) test { + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + return test{ + auth: &mockAcmeAuthority{ + newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, ops.Contact, nar.Contact) + assert.Equals(t, ops.Key, jwk) + return &acc, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, in, []string{accID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID) + }, + }, + ctx: ctx, + statusCode: 201, + } + }, + "ok/return-existing": func(t *testing.T) test { + nar := &NewAccountRequest{ + OnlyReturnExisting: true, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, accContextKey, &acc) + return test{ + auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, in, []string{accID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID) + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.NewAccount(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(acc) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Location"], + []string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID)}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} + +func TestHandlerGetUpdateAccount(t *testing.T) { + accID := "accountID" + acc := acme.Account{ + ID: accID, + Status: "valid", + Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), + } + prov := newProv() + + // Request with chi context + url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/no-payload": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/nil-payload": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/unmarshal-payload-error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + } + }, + "fail/malformed-payload-error": func(t *testing.T) test { + uar := &UpdateAccountRequest{ + Contact: []string{"foo", ""}, + } + b, err := json.Marshal(uar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + } + }, + "fail/Deactivate-error": func(t *testing.T) test { + uar := &UpdateAccountRequest{ + Status: "deactivated", + } + b, err := json.Marshal(uar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, id, accID) + return nil, acme.ServerInternalErr(errors.New("force")) + }, + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "fail/UpdateAccount-error": func(t *testing.T) test { + uar := &UpdateAccountRequest{ + Contact: []string{"foo", "bar"}, + } + b, err := json.Marshal(uar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, id, accID) + assert.Equals(t, contacts, uar.Contact) + return nil, acme.ServerInternalErr(errors.New("force")) + }, + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok/deactivate": func(t *testing.T) test { + uar := &UpdateAccountRequest{ + Status: "deactivated", + } + b, err := json.Marshal(uar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, id, accID) + return &acc, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{accID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID) + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + "ok/new-account": func(t *testing.T) test { + uar := &UpdateAccountRequest{ + Contact: []string{"foo", "bar"}, + } + b, err := json.Marshal(uar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, id, accID) + assert.Equals(t, contacts, uar.Contact) + return &acc, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{accID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID) + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + "ok/post-as-get": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + return test{ + auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{accID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID) + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetUpdateAccount(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(acc) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Location"], + []string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", + acme.URLSafeProvisionerName(prov), accID)}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} diff --git a/acme/api/handler.go b/acme/api/handler.go new file mode 100644 index 00000000..423c08ea --- /dev/null +++ b/acme/api/handler.go @@ -0,0 +1,214 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" +) + +func link(url, typ string) string { + return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) +} + +type contextKey string + +const ( + accContextKey = contextKey("acc") + jwsContextKey = contextKey("jws") + jwkContextKey = contextKey("jwk") + payloadContextKey = contextKey("payload") + provisionerContextKey = contextKey("provisioner") +) + +type payloadInfo struct { + value []byte + isPostAsGet bool + isEmptyJSON bool +} + +func accountFromContext(r *http.Request) (*acme.Account, error) { + val, ok := r.Context().Value(accContextKey).(*acme.Account) + if !ok || val == nil { + return nil, acme.AccountDoesNotExistErr(nil) + } + return val, nil +} +func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) { + val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey) + if !ok || val == nil { + return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context")) + } + return val, nil +} +func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) { + val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature) + if !ok || val == nil { + return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context")) + } + return val, nil +} +func payloadFromContext(r *http.Request) (*payloadInfo, error) { + val, ok := r.Context().Value(payloadContextKey).(*payloadInfo) + if !ok || val == nil { + return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context")) + } + return val, nil +} +func provisionerFromContext(r *http.Request) (provisioner.Interface, error) { + val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface) + if !ok || val == nil { + return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")) + } + return val, nil +} + +// New returns a new ACME API router. +func New(acmeAuth acme.Interface) api.RouterHandler { + return &Handler{acmeAuth} +} + +// Handler is the ACME request handler. +type Handler struct { + Auth acme.Interface +} + +// Route traffic and implement the Router interface. +func (h *Handler) Route(r api.Router) { + getLink := h.Auth.GetLink + // Standard ACME API + r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce))) + r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce))) + r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory))) + r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory))) + + extractPayloadByJWK := func(next nextHTTP) nextHTTP { + return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))) + } + extractPayloadByKid := func(next nextHTTP) nextHTTP { + return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))) + } + + r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount)) + r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) + r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder)) + r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) + r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) + r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) + r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) + r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge)) + r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) +} + +// GetNonce just sets the right header since a Nonce is added to each response +// by middleware by default. +func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusNoContent) + } + return +} + +// GetDirectory is the ACME resource for returning a directory configuration +// for client configuration. +func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + dir := h.Auth.GetDirectory(prov) + api.JSON(w, dir) + return +} + +// GetAuthz ACME api for retrieving an Authz. +func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID")) + if err != nil { + api.WriteError(w, err) + return + } + + w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID())) + api.JSON(w, authz) + return +} + +// GetChallenge ACME api for retrieving a Challenge. +func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + // Just verify that the payload was set, since we're not strictly adhering + // to ACME V2 spec for reasons specified below. + _, err = payloadFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + + // NOTE: We should be checking that the request is either a POST-as-GET, or + // that the payload is an empty JSON block ({}). However, older ACME clients + // still send a vestigial body (rather than an empty JSON block) and + // strict enforcement would render these clients broken. For the time being + // we'll just ignore the body. + var ( + ch *acme.Challenge + chID = chi.URLParam(r, "chID") + ) + ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey()) + if err != nil { + api.WriteError(w, err) + return + } + + getLink := h.Auth.GetLink + w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up")) + w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID())) + api.JSON(w, ch) + return +} + +// GetCertificate ACME api for retrieving a Certificate. +func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + certID := chi.URLParam(r, "certID") + certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID) + if err != nil { + api.WriteError(w, err) + return + } + + w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8") + w.Write(certBytes) + return +} diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go new file mode 100644 index 00000000..cf256eb2 --- /dev/null +++ b/acme/api/handler_test.go @@ -0,0 +1,771 @@ +package api + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io/ioutil" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/jose" +) + +type mockAcmeAuthority struct { + deactivateAccount func(provisioner.Interface, string) (*acme.Account, error) + finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error) + getAccount func(p provisioner.Interface, id string) (*acme.Account, error) + getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error) + getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error) + getCertificate func(accID string, id string) ([]byte, error) + getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error) + getDirectory func(provisioner.Interface) *acme.Directory + getLink func(acme.Link, string, bool, ...string) string + getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error) + getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error) + loadProvisionerByID func(string) (provisioner.Interface, error) + newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error) + newNonce func() (string, error) + newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error) + updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error) + useNonce func(string) error + validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) + ret1 interface{} + err error +} + +func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) { + if m.deactivateAccount != nil { + return m.deactivateAccount(p, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Account), m.err +} + +func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { + if m.finalizeOrder != nil { + return m.finalizeOrder(p, accID, id, csr) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Order), m.err +} + +func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) { + if m.getAccount != nil { + return m.getAccount(p, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Account), m.err +} + +func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + if m.getAccountByKey != nil { + return m.getAccountByKey(p, jwk) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Account), m.err +} + +func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) { + if m.getAuthz != nil { + return m.getAuthz(p, accID, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Authz), m.err +} + +func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) { + if m.getCertificate != nil { + return m.getCertificate(accID, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.([]byte), m.err +} + +func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) { + if m.getChallenge != nil { + return m.getChallenge(p, accID, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Challenge), m.err +} + +func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory { + if m.getDirectory != nil { + return m.getDirectory(p) + } + return m.ret1.(*acme.Directory) +} + +func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string { + if m.getLink != nil { + return m.getLink(typ, provID, abs, in...) + } + return m.ret1.(string) +} + +func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) { + if m.getOrder != nil { + return m.getOrder(p, accID, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Order), m.err +} + +func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) { + if m.getOrdersByAccount != nil { + return m.getOrdersByAccount(p, id) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.([]string), m.err +} + +func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { + if m.loadProvisionerByID != nil { + return m.loadProvisionerByID(provID) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { + if m.newAccount != nil { + return m.newAccount(p, ops) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Account), m.err +} + +func (m *mockAcmeAuthority) NewNonce() (string, error) { + if m.newNonce != nil { + return m.newNonce() + } else if m.err != nil { + return "", m.err + } + return m.ret1.(string), m.err +} + +func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + if m.newOrder != nil { + return m.newOrder(p, ops) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Order), m.err +} + +func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) { + if m.updateAccount != nil { + return m.updateAccount(p, id, contact) + } else if m.err != nil { + return nil, m.err + } + return m.ret1.(*acme.Account), m.err +} + +func (m *mockAcmeAuthority) UseNonce(nonce string) error { + if m.useNonce != nil { + return m.useNonce(nonce) + } + return m.err +} + +func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { + switch { + case m.validateChallenge != nil: + return m.validateChallenge(p, accID, id, jwk) + case m.err != nil: + return nil, m.err + default: + return m.ret1.(*acme.Challenge), m.err + } +} + +func TestHandlerGetNonce(t *testing.T) { + tests := []struct { + name string + statusCode int + }{ + {"GET", 204}, + {"HEAD", 200}, + } + + // Request with chi context + req := httptest.NewRequest("GET", "http://ca.smallstep.com/nonce", nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(nil).(*Handler) + w := httptest.NewRecorder() + req.Method = tt.name + h.GetNonce(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("Handler.GetNonce StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + }) + } +} + +func TestHandlerGetDirectory(t *testing.T) { + auth := acme.NewAuthority(nil, "ca.smallstep.com", "acme", nil) + prov := newProv() + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov)) + + expDir := acme.Directory{ + NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)), + NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)), + NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)), + RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)), + KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)), + } + + type test struct { + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetDirectory(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + var dir acme.Directory + json.Unmarshal(bytes.TrimSpace(body), &dir) + assert.Equals(t, dir, expDir) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} + +func TestHandlerGetAuthz(t *testing.T) { + expiry := time.Now().UTC().Add(6 * time.Hour) + az := acme.Authz{ + ID: "authzID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "example.com", + }, + Status: "pending", + Expires: expiry.Format(time.RFC3339), + Wildcard: false, + Challenges: []*acme.Challenge{ + { + Type: "http-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chHTTPID", + ID: "chHTTP01ID", + AuthzID: "authzID", + }, + { + Type: "dns-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chDNSID", + ID: "chDNSID", + AuthzID: "authzID", + }, + }, + } + prov := newProv() + + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("authzID", az.ID) + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s", + acme.URLSafeProvisionerName(prov), az.ID) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + auth: &mockAcmeAuthority{}, + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/getAuthz-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.ServerInternalErr(errors.New("force")), + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, az.ID) + return &az, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, typ, acme.AuthzLink) + assert.True(t, abs) + assert.Equals(t, in, []string{az.ID}) + return url + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetAuthz(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + //var gotAz acme.Authz + //assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &gotAz)) + expB, err := json.Marshal(az) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} + +func TestHandlerGetCertificate(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + certBytes := append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: inter.Raw, + })...) + certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: root.Raw, + })...) + certID := "certID" + + prov := newProv() + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("certID", certID) + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s", + acme.URLSafeProvisionerName(prov), certID) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-account": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), accContextKey, nil) + return test{ + auth: &mockAcmeAuthority{}, + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/getCertificate-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.ServerInternalErr(errors.New("force")), + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + getCertificate: func(accID, id string) ([]byte, error) { + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, certID) + return certBytes, nil + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetCertificate(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes)) + assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"}) + } + }) + } +} + +func ch() acme.Challenge { + return acme.Challenge{ + Type: "http-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chID", + ID: "chID", + AuthzID: "authzID", + } +} + +func TestHandlerGetChallenge(t *testing.T) { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("chID", "chID") + url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID") + prov := newProv() + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + ch acme.Challenge + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/no-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/nil-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/validate-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.UnauthorizedErr(nil), + }, + ctx: ctx, + statusCode: 401, + problem: acme.UnauthorizedErr(nil), + } + }, + "fail/get-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.UnauthorizedErr(nil), + }, + ctx: ctx, + statusCode: 401, + problem: acme.UnauthorizedErr(nil), + } + }, + "ok/validate-challenge": func(t *testing.T) test { + key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ID: "accID", Key: key} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ch := ch() + ch.Status = "valid" + ch.Validated = time.Now().UTC().Format(time.RFC3339) + count := 0 + return test{ + auth: &mockAcmeAuthority{ + validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, ch.ID) + assert.Equals(t, jwk.KeyID, key.KeyID) + return &ch, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + var ret string + switch count { + case 0: + assert.Equals(t, typ, acme.AuthzLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.AuthzID}) + ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID) + case 1: + assert.Equals(t, typ, acme.ChallengeLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.ID}) + ret = url + } + count++ + return ret + }, + }, + ctx: ctx, + statusCode: 200, + ch: ch, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetChallenge(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(tc.ch) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf(";rel=\"up\"", tc.ch.AuthzID)}) + assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} diff --git a/acme/api/middleware.go b/acme/api/middleware.go new file mode 100644 index 00000000..3f4c99a5 --- /dev/null +++ b/acme/api/middleware.go @@ -0,0 +1,377 @@ +package api + +import ( + "context" + "crypto/rsa" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/logging" + "github.com/smallstep/cli/crypto/keys" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql" +) + +type nextHTTP = func(http.ResponseWriter, *http.Request) + +func logNonce(w http.ResponseWriter, nonce string) { + if rl, ok := w.(logging.ResponseLogger); ok { + m := map[string]interface{}{ + "nonce": nonce, + } + rl.WithFields(m) + } +} + +// addNonce is a middleware that adds a nonce to the response header. +func (h *Handler) addNonce(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + nonce, err := h.Auth.NewNonce() + if err != nil { + api.WriteError(w, err) + return + } + w.Header().Set("Replay-Nonce", nonce) + w.Header().Set("Cache-Control", "no-store") + logNonce(w, nonce) + next(w, r) + return + } +} + +// addDirLink is a middleware that adds a 'Link' response reader with the +// directory index url. +func (h *Handler) addDirLink(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index")) + next(w, r) + return + } +} + +// verifyContentType is a middleware that verifies that content type is +// application/jose+json. +func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + ct := r.Header.Get("Content-Type") + var expected []string + if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) { + // GET /certificate requests allow a greater range of content types. + expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} + } else { + // By default every request should have content-type applictaion/jose+json. + expected = []string{"application/jose+json"} + } + for _, e := range expected { + if ct == e { + next(w, r) + return + } + } + api.WriteError(w, acme.MalformedErr(errors.Errorf( + "expected content-type to be in %s, but got %s", expected, ct))) + return + } +} + +// parseJWS is a middleware that parses a request body into a JSONWebSignature struct. +func (h *Handler) parseJWS(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body"))) + return + } + jws, err := jose.ParseJWS(string(body)) + if err != nil { + api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) + return + } + ctx := context.WithValue(r.Context(), jwsContextKey, jws) + next(w, r.WithContext(ctx)) + return + } +} + +// validateJWS checks the request body for to verify that it meets ACME +// requirements for a JWS. +// +// The JWS MUST NOT have multiple signatures +// The JWS Unencoded Payload Option [RFC7797] MUST NOT be used +// The JWS Unprotected Header [RFC7515] MUST NOT be used +// The JWS Payload MUST NOT be detached +// The JWS Protected Header MUST include the following fields: +// * “alg” (Algorithm) +// * This field MUST NOT contain “none” or a Message Authentication Code +// (MAC) algorithm (e.g. one in which the algorithm registry description +// mentions MAC/HMAC). +// * “nonce” (defined in Section 6.5) +// * “url” (defined in Section 6.4) +// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below +func (h *Handler) validateJWS(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + jws, err := jwsFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + if len(jws.Signatures) == 0 { + api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature"))) + return + } + if len(jws.Signatures) > 1 { + api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature"))) + return + } + + sig := jws.Signatures[0] + uh := sig.Unprotected + if len(uh.KeyID) > 0 || + uh.JSONWebKey != nil || + len(uh.Algorithm) > 0 || + len(uh.Nonce) > 0 || + len(uh.ExtraHeaders) > 0 { + api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used"))) + return + } + hdr := sig.Protected + switch hdr.Algorithm { + case jose.RS256, jose.RS384, jose.RS512: + if hdr.JSONWebKey != nil { + switch k := hdr.JSONWebKey.Key.(type) { + case *rsa.PublicKey: + if k.Size() < keys.MinRSAKeyBytes { + api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+ + "keys must be at least %d bits (%d bytes) in size", + 8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes))) + return + } + default: + api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match"))) + return + } + } + case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: + // we good + default: + api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm))) + return + } + + // Check the validity/freshness of the Nonce. + if err := h.Auth.UseNonce(hdr.Nonce); err != nil { + api.WriteError(w, err) + return + } + + // Check that the JWS url matches the requested url. + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + if !ok { + api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header"))) + return + } + reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} + if jwsURL != reqURL.String() { + api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))) + return + } + + if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { + api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive"))) + return + } + if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { + api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header"))) + return + } + next(w, r) + return + } +} + +// extractJWK is a middleware that extracts the JWK from the JWS and saves it +// in the context. Make sure to parse and validate the JWS before running this +// middleware. +func (h *Handler) extractJWK(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + jws, err := jwsFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + jwk := jws.Signatures[0].Protected.JSONWebKey + if jwk == nil { + api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header"))) + return + } + if !jwk.Valid() { + api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) + return + } + ctx = context.WithValue(ctx, jwkContextKey, jwk) + acc, err := h.Auth.GetAccountByKey(prov, jwk) + switch { + case nosql.IsErrNotFound(err): + // For NewAccount requests ... + break + case err != nil: + api.WriteError(w, err) + return + default: + if !acc.IsValid() { + api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + return + } + ctx = context.WithValue(ctx, accContextKey, acc) + } + next(w, r.WithContext(ctx)) + return + } +} + +// lookupProvisioner loads the provisioner associated with the request. +// Responsds 404 if the provisioner does not exist. +func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + name := chi.URLParam(r, "provisionerID") + provID, err := url.PathUnescape(name) + if err != nil { + api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name))) + return + } + p, err := h.Auth.LoadProvisionerByID("acme/" + provID) + if err != nil { + api.WriteError(w, err) + return + } + if p.GetType() != provisioner.TypeACME { + api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) + return + } + ctx = context.WithValue(ctx, provisionerContextKey, p) + next(w, r.WithContext(ctx)) + return + } +} + +// lookupJWK loads the JWK associated with the acme account referenced by the +// kid parameter of the signed payload. +// Make sure to parse and validate the JWS before running this middleware. +func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + jws, err := jwsFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + + kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "") + kid := jws.Signatures[0].Protected.KeyID + if !strings.HasPrefix(kid, kidPrefix) { + api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ + "required prefix; expected %s, but got %s", kidPrefix, kid))) + return + } + + accID := strings.TrimPrefix(kid, kidPrefix) + acc, err := h.Auth.GetAccount(prov, accID) + switch { + case nosql.IsErrNotFound(err): + api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + return + case err != nil: + api.WriteError(w, err) + return + default: + if !acc.IsValid() { + api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + return + } + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, jwkContextKey, acc.Key) + next(w, r.WithContext(ctx)) + return + } + } +} + +// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. +// Make sure to parse and validate the JWS before running this middleware. +func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + jws, err := jwsFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + jwk, err := jwkFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { + api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match"))) + return + } + payload, err := jws.Verify(jwk) + if err != nil { + api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) + return + } + ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{ + value: payload, + isPostAsGet: string(payload) == "", + isEmptyJSON: string(payload) == "{}", + }) + next(w, r.WithContext(ctx)) + return + } +} + +// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). +func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + payload, err := payloadFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + if !payload.isPostAsGet { + api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET"))) + return + } + next(w, r) + return + } +} diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go new file mode 100644 index 00000000..18fafd8d --- /dev/null +++ b/acme/api/middleware_test.go @@ -0,0 +1,1550 @@ +package api + +import ( + "bytes" + "context" + "crypto" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql/database" +) + +var testBody = []byte("foo") + +func testNext(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + return +} + +func TestHandlerAddNonce(t *testing.T) { + url := "https://ca.smallstep.com/acme/new-nonce" + type test struct { + auth acme.Interface + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/AddNonce-error": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{ + newNonce: func() (string, error) { + return "", acme.ServerInternalErr(errors.New("force")) + }, + }, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{ + newNonce: func() (string, error) { + return "bar", nil + }, + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + w := httptest.NewRecorder() + h.addNonce(testNext)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, res.Header["Replay-Nonce"], []string{"bar"}) + assert.Equals(t, res.Header["Cache-Control"], []string{"no-store"}) + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerAddDirLink(t *testing.T) { + url := "https://ca.smallstep.com/acme/new-nonce" + prov := newProv() + type test struct { + auth acme.Interface + link string + statusCode int + ctx context.Context + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "ok": func(t *testing.T) test { + link := "https://ca.smallstep.com/acme/directory" + return test{ + auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + return link + }, + }, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + link: link, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.addDirLink(testNext)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s>;rel=\"index\"", tc.link)}) + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerVerifyContentType(t *testing.T) { + prov := newProv() + url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/abc123", + acme.URLSafeProvisionerName(prov)) + type test struct { + h Handler + ctx context.Context + contentType string + problem *acme.Error + statusCode int + url string + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + h: Handler{Auth: &mockAcmeAuthority{}}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + h: Handler{Auth: &mockAcmeAuthority{}}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/general-bad-content-type": func(t *testing.T) test { + return test{ + h: Handler{ + Auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.CertificateLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, abs, false) + assert.Equals(t, in, []string{""}) + return "/certificate/" + }, + }, + }, + url: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", + acme.URLSafeProvisionerName(prov)), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + contentType: "foo", + statusCode: 400, + problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")), + } + }, + "fail/certificate-bad-content-type": func(t *testing.T) test { + return test{ + h: Handler{ + Auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.CertificateLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, abs, false) + assert.Equals(t, in, []string{""}) + return "/certificate/" + }, + }, + }, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + contentType: "foo", + statusCode: 400, + problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")), + } + }, + "ok": func(t *testing.T) test { + return test{ + h: Handler{ + Auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.CertificateLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, abs, false) + assert.Equals(t, in, []string{""}) + return "/certificate/" + }, + }, + }, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + contentType: "application/jose+json", + statusCode: 200, + } + }, + "ok/certificate/pkix-cert": func(t *testing.T) test { + return test{ + h: Handler{ + Auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.CertificateLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, abs, false) + assert.Equals(t, in, []string{""}) + return "/certificate/" + }, + }, + }, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + contentType: "application/pkix-cert", + statusCode: 200, + } + }, + "ok/certificate/jose+json": func(t *testing.T) test { + return test{ + h: Handler{ + Auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.CertificateLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, abs, false) + assert.Equals(t, in, []string{""}) + return "/certificate/" + }, + }, + }, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + contentType: "application/jose+json", + statusCode: 200, + } + }, + "ok/certificate/pkcs7-mime": func(t *testing.T) test { + return test{ + h: Handler{ + Auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.CertificateLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.Equals(t, abs, false) + assert.Equals(t, in, []string{""}) + return "/certificate/" + }, + }, + }, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + contentType: "application/pkcs7-mime", + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + _url := url + if tc.url != "" { + _url = tc.url + } + req := httptest.NewRequest("GET", _url, nil) + req = req.WithContext(tc.ctx) + req.Header.Add("Content-Type", tc.contentType) + w := httptest.NewRecorder() + tc.h.verifyContentType(testNext)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerIsPostAsGet(t *testing.T) { + url := "https://ca.smallstep.com/acme/new-account" + type test struct { + ctx context.Context + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-payload": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/nil-payload": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), payloadContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/not-post-as-get": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), + statusCode: 400, + problem: acme.MalformedErr(errors.New("expected POST-as-GET")), + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(nil).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.isPostAsGet(testNext)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +type errReader int + +func (errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("force") +} +func (errReader) Close() error { + return nil +} + +func TestHandlerParseJWS(t *testing.T) { + url := "https://ca.smallstep.com/acme/new-account" + type test struct { + next nextHTTP + body io.Reader + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/read-body-error": func(t *testing.T) test { + return test{ + body: errReader(0), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("failed to read request body: force")), + } + }, + "fail/parse-jws-error": func(t *testing.T) test { + return test{ + body: strings.NewReader("foo"), + statusCode: 400, + problem: acme.MalformedErr(errors.New("failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts")), + } + }, + "ok": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, new(jose.SignerOptions)) + assert.FatalError(t, err) + signed, err := signer.Sign([]byte("baz")) + assert.FatalError(t, err) + expRaw, err := signed.CompactSerialize() + assert.FatalError(t, err) + + return test{ + body: strings.NewReader(expRaw), + next: func(w http.ResponseWriter, r *http.Request) { + jws, err := jwsFromContext(r) + assert.FatalError(t, err) + gotRaw, err := jws.CompactSerialize() + assert.FatalError(t, err) + assert.Equals(t, gotRaw, expRaw) + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(nil).(*Handler) + req := httptest.NewRequest("GET", url, tc.body) + w := httptest.NewRecorder() + h.parseJWS(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + _pub := jwk.Public() + pub := &_pub + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign([]byte("baz")) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + url := "https://ca.smallstep.com/acme/account/1234" + type test struct { + ctx context.Context + next func(http.ResponseWriter, *http.Request) + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-jws": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/nil-jws": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/no-jwk": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + } + }, + "fail/nil-jwk": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + return test{ + ctx: context.WithValue(ctx, jwkContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + } + }, + "fail/verify-jws-failure": func(t *testing.T) test { + _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + _pub := _jwk.Public() + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, &_pub) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("error verifying jws: square/go-jose: error in cryptographic primitive")), + } + }, + "fail/algorithm-mismatch": func(t *testing.T) test { + _pub := *pub + clone := &_pub + clone.Algorithm = jose.HS256 + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, clone) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("verifier and signature algorithm do not match")), + } + }, + "ok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, pub) + return test{ + ctx: ctx, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + p, err := payloadFromContext(r) + assert.FatalError(t, err) + if assert.NotNil(t, p) { + assert.Equals(t, p.value, []byte("baz")) + assert.False(t, p.isPostAsGet) + assert.False(t, p.isEmptyJSON) + } + w.Write(testBody) + }, + } + }, + "ok/empty-algorithm-in-jwk": func(t *testing.T) test { + _pub := *pub + clone := &_pub + clone.Algorithm = "" + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, pub) + return test{ + ctx: ctx, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + p, err := payloadFromContext(r) + assert.FatalError(t, err) + if assert.NotNil(t, p) { + assert.Equals(t, p.value, []byte("baz")) + assert.False(t, p.isPostAsGet) + assert.False(t, p.isEmptyJSON) + } + w.Write(testBody) + }, + } + }, + "ok/post-as-get": func(t *testing.T) test { + _jws, err := signer.Sign([]byte("")) + assert.FatalError(t, err) + _raw, err := _jws.CompactSerialize() + assert.FatalError(t, err) + _parsed, err := jose.ParseJWS(_raw) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwkContextKey, pub) + return test{ + ctx: ctx, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + p, err := payloadFromContext(r) + assert.FatalError(t, err) + if assert.NotNil(t, p) { + assert.Equals(t, p.value, []byte{}) + assert.True(t, p.isPostAsGet) + assert.False(t, p.isEmptyJSON) + } + w.Write(testBody) + }, + } + }, + "ok/empty-json": func(t *testing.T) test { + _jws, err := signer.Sign([]byte("{}")) + assert.FatalError(t, err) + _raw, err := _jws.CompactSerialize() + assert.FatalError(t, err) + _parsed, err := jose.ParseJWS(_raw) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwkContextKey, pub) + return test{ + ctx: ctx, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + p, err := payloadFromContext(r) + assert.FatalError(t, err) + if assert.NotNil(t, p) { + assert.Equals(t, p.value, []byte("{}")) + assert.False(t, p.isPostAsGet) + assert.True(t, p.isEmptyJSON) + } + w.Write(testBody) + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(nil).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.verifyAndExtractJWSPayload(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerLookupJWK(t *testing.T) { + prov := newProv() + url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", + acme.URLSafeProvisionerName(prov)) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + accID := "account-id" + prefix := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", + acme.URLSafeProvisionerName(prov)) + so := new(jose.SignerOptions) + so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign([]byte("baz")) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + type test struct { + auth acme.Interface + ctx context.Context + next func(http.ResponseWriter, *http.Request) + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-jws": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/nil-jws": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/no-kid": func(t *testing.T) test { + _signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, new(jose.SignerOptions)) + assert.FatalError(t, err) + _jws, err := _signer.Sign([]byte("baz")) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) + return test{ + auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{""}) + return prefix + }, + }, + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got ", prefix)), + } + }, + "fail/bad-kid-prefix": func(t *testing.T) test { + _so := new(jose.SignerOptions) + _so.WithHeader("kid", "foo") + _signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, _so) + assert.FatalError(t, err) + _jws, err := _signer.Sign([]byte("baz")) + assert.FatalError(t, err) + _raw, err := _jws.CompactSerialize() + assert.FatalError(t, err) + _parsed, err := jose.ParseJWS(_raw) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _parsed) + return test{ + auth: &mockAcmeAuthority{ + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{""}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + }, + }, + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got foo", prefix)), + } + }, + "fail/account-not-found": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + auth: &mockAcmeAuthority{ + getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, accID) + return nil, database.ErrNotFound + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{""}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + }, + }, + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/GetAccount-error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + auth: &mockAcmeAuthority{ + getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, accID) + return nil, acme.ServerInternalErr(errors.New("force")) + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{""}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + }, + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "fail/account-not-valid": func(t *testing.T) test { + acc := &acme.Account{Status: "deactivated"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + auth: &mockAcmeAuthority{ + getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, accID) + return acc, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{""}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + }, + }, + ctx: ctx, + statusCode: 401, + problem: acme.UnauthorizedErr(errors.New("account is not active")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{Status: "valid", Key: jwk} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + auth: &mockAcmeAuthority{ + getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, accID) + return acc, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{""}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + }, + }, + ctx: ctx, + next: func(w http.ResponseWriter, r *http.Request) { + _acc, err := accountFromContext(r) + assert.FatalError(t, err) + assert.Equals(t, _acc, acc) + _jwk, err := jwkFromContext(r) + assert.FatalError(t, err) + assert.Equals(t, _jwk, jwk) + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.lookupJWK(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerExtractJWK(t *testing.T) { + prov := newProv() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + kid, err := jwk.Thumbprint(crypto.SHA256) + assert.FatalError(t, err) + pub := jwk.Public() + pub.KeyID = base64.RawURLEncoding.EncodeToString(kid) + + so := new(jose.SignerOptions) + so.WithHeader("jwk", pub) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign([]byte("baz")) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", + acme.URLSafeProvisionerName(prov)) + type test struct { + auth acme.Interface + ctx context.Context + next func(http.ResponseWriter, *http.Request) + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-jws": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/nil-jws": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/nil-jwk": func(t *testing.T) test { + _jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + JSONWebKey: nil, + }, + }, + }, + } + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("jwk expected in protected header")), + } + }, + "fail/invalid-jwk": func(t *testing.T) test { + _jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + JSONWebKey: &jose.JSONWebKey{Key: "foo"}, + }, + }, + }, + } + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("invalid jwk in protected header")), + } + }, + "fail/GetAccountByKey-error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + ctx: ctx, + auth: &mockAcmeAuthority{ + getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, jwk.KeyID, pub.KeyID) + return nil, acme.ServerInternalErr(errors.New("force")) + }, + }, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "fail/account-not-valid": func(t *testing.T) test { + acc := &acme.Account{Status: "deactivated"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + ctx: ctx, + auth: &mockAcmeAuthority{ + getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, jwk.KeyID, pub.KeyID) + return acc, nil + }, + }, + statusCode: 401, + problem: acme.UnauthorizedErr(errors.New("account is not active")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{Status: "valid"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + ctx: ctx, + auth: &mockAcmeAuthority{ + getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, jwk.KeyID, pub.KeyID) + return acc, nil + }, + }, + next: func(w http.ResponseWriter, r *http.Request) { + _acc, err := accountFromContext(r) + assert.FatalError(t, err) + assert.Equals(t, _acc, acc) + _jwk, err := jwkFromContext(r) + assert.FatalError(t, err) + assert.Equals(t, _jwk.KeyID, pub.KeyID) + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + "ok/no-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + ctx: ctx, + auth: &mockAcmeAuthority{ + getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + assert.Equals(t, p, prov) + assert.Equals(t, jwk.KeyID, pub.KeyID) + return nil, database.ErrNotFound + }, + }, + next: func(w http.ResponseWriter, r *http.Request) { + _acc, err := accountFromContext(r) + assert.NotNil(t, err) + assert.Nil(t, _acc) + _jwk, err := jwkFromContext(r) + assert.FatalError(t, err) + assert.Equals(t, _jwk.KeyID, pub.KeyID) + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.extractJWK(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} + +func TestHandlerValidateJWS(t *testing.T) { + url := "https://ca.smallstep.com/acme/account/1234" + type test struct { + auth acme.Interface + ctx context.Context + next func(http.ResponseWriter, *http.Request) + problem *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-jws": func(t *testing.T) test { + return test{ + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/nil-jws": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + } + }, + "fail/no-signature": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), + statusCode: 400, + problem: acme.MalformedErr(errors.New("request body does not contain a signature")), + } + }, + "fail/more-than-one-signature": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + {}, + {}, + }, + } + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.New("request body contains more than one signature")), + } + }, + "fail/unprotected-header-not-empty": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + {Unprotected: jose.Header{Nonce: "abc"}}, + }, + } + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.New("unprotected header must not be used")), + } + }, + "fail/unsuitable-algorithm-none": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + {Protected: jose.Header{Algorithm: "none"}}, + }, + } + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")), + } + }, + "fail/unsuitable-algorithm-mac": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + {Protected: jose.Header{Algorithm: jose.HS256}}, + }, + } + return test{ + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), + } + }, + "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + pub := jwk.Public() + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.RS256, + JSONWebKey: &pub, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), + } + }, + "fail/rsa-key-too-small": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("RSA", "", "", "sig", "", 1024) + assert.FatalError(t, err) + pub := jwk.Public() + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.RS256, + JSONWebKey: &pub, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), + } + }, + "fail/UseNonce-error": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + {Protected: jose.Header{Algorithm: jose.ES256}}, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return acme.ServerInternalErr(errors.New("force")) + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "fail/no-url-header": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + {Protected: jose.Header{Algorithm: jose.ES256}}, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.New("jws missing url protected header")), + } + }, + "fail/url-mismatch": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": "foo", + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), + } + }, + "fail/both-jwk-kid": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + pub := jwk.Public() + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + KeyID: "bar", + JSONWebKey: &pub, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), + } + }, + "fail/no-jwk-kid": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + statusCode: 400, + problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), + } + }, + "ok/kid": func(t *testing.T) test { + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + KeyID: "bar", + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + "ok/jwk/ecdsa": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + pub := jwk.Public() + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + JSONWebKey: &pub, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + "ok/jwk/rsa": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("RSA", "", "", "sig", "", 2048) + assert.FatalError(t, err) + pub := jwk.Public() + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.RS256, + JSONWebKey: &pub, + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": url, + }, + }, + }, + }, + } + return test{ + auth: &mockAcmeAuthority{ + useNonce: func(n string) error { + return nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, jws), + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + return + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.validateJWS(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} diff --git a/acme/api/order.go b/acme/api/order.go new file mode 100644 index 00000000..83d1e26e --- /dev/null +++ b/acme/api/order.go @@ -0,0 +1,164 @@ +package api + +import ( + "crypto/x509" + "encoding/base64" + "encoding/json" + "net/http" + "time" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api" +) + +// NewOrderRequest represents the body for a NewOrder request. +type NewOrderRequest struct { + Identifiers []acme.Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` +} + +// Validate validates a new-order request body. +func (n *NewOrderRequest) Validate() error { + if len(n.Identifiers) == 0 { + return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")) + } + for _, id := range n.Identifiers { + if id.Type != "dns" { + return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type)) + } + } + return nil +} + +// FinalizeRequest captures the body for a Finalize order request. +type FinalizeRequest struct { + CSR string `json:"csr"` + csr *x509.CertificateRequest +} + +// Validate validates a finalize request body. +func (f *FinalizeRequest) Validate() error { + var err error + csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) + if err != nil { + return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr")) + } + f.csr, err = x509.ParseCertificateRequest(csrBytes) + if err != nil { + return acme.MalformedErr(errors.Wrap(err, "unable to parse csr")) + } + if err = f.csr.CheckSignature(); err != nil { + return acme.MalformedErr(errors.Wrap(err, "csr failed signature check")) + } + return nil +} + +// NewOrder ACME api for creating a new order. +func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + payload, err := payloadFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + var nor NewOrderRequest + if err := json.Unmarshal(payload.value, &nor); err != nil { + api.WriteError(w, acme.MalformedErr(errors.Wrap(err, + "failed to unmarshal new-order request payload"))) + return + } + if err := nor.Validate(); err != nil { + api.WriteError(w, err) + return + } + + o, err := h.Auth.NewOrder(prov, acme.OrderOptions{ + AccountID: acc.GetID(), + Identifiers: nor.Identifiers, + NotBefore: nor.NotBefore, + NotAfter: nor.NotAfter, + }) + if err != nil { + api.WriteError(w, err) + return + } + + w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID())) + api.JSONStatus(w, o, http.StatusCreated) + return +} + +// GetOrder ACME api for retrieving an order. +func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + oid := chi.URLParam(r, "ordID") + o, err := h.Auth.GetOrder(prov, acc.GetID(), oid) + if err != nil { + api.WriteError(w, err) + return + } + + w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID())) + api.JSON(w, o) + return +} + +// FinalizeOrder attemptst to finalize an order and create a certificate. +func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { + prov, err := provisionerFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + acc, err := accountFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + payload, err := payloadFromContext(r) + if err != nil { + api.WriteError(w, err) + return + } + var fr FinalizeRequest + if err := json.Unmarshal(payload.value, &fr); err != nil { + api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload"))) + return + } + if err := fr.Validate(); err != nil { + api.WriteError(w, err) + return + } + + oid := chi.URLParam(r, "ordID") + o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr) + if err != nil { + api.WriteError(w, err) + return + } + + w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID)) + api.JSON(w, o) + return +} diff --git a/acme/api/order_test.go b/acme/api/order_test.go new file mode 100644 index 00000000..68bd4f46 --- /dev/null +++ b/acme/api/order_test.go @@ -0,0 +1,757 @@ +package api + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/crypto/pemutil" +) + +func TestNewOrderRequestValidate(t *testing.T) { + type test struct { + nor *NewOrderRequest + nbf, naf time.Time + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-identifiers": func(t *testing.T) test { + return test{ + nor: &NewOrderRequest{}, + err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")), + } + }, + "fail/bad-identifier": func(t *testing.T) test { + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "foo", Value: "bar.com"}, + }, + }, + err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")), + } + }, + "ok": func(t *testing.T) test { + nbf := time.Now().UTC().Add(time.Minute) + naf := time.Now().UTC().Add(5 * time.Minute) + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "bar.com"}, + }, + NotAfter: naf, + NotBefore: nbf, + }, + nbf: nbf, + naf: naf, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + if err := tc.nor.Validate(); err != nil { + if assert.NotNil(t, err) { + ae, ok := err.(*acme.Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + if tc.nbf.IsZero() { + assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute))) + assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute))) + } else { + assert.Equals(t, tc.nor.NotBefore, tc.nbf) + } + if tc.naf.IsZero() { + assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour))) + assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute))) + } else { + assert.Equals(t, tc.nor.NotAfter, tc.naf) + } + } + } + }) + } +} + +func TestFinalizeRequestValidate(t *testing.T) { + _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") + assert.FatalError(t, err) + csr, ok := _csr.(*x509.CertificateRequest) + assert.Fatal(t, ok) + type test struct { + fr *FinalizeRequest + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/parse-csr-error": func(t *testing.T) test { + return test{ + fr: &FinalizeRequest{}, + err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")), + } + }, + "fail/invalid-csr-signature": func(t *testing.T) test { + b, err := pemutil.Read("../../authority/testdata/certs/badsig.csr") + assert.FatalError(t, err) + c, ok := b.(*x509.CertificateRequest) + assert.Fatal(t, ok) + return test{ + fr: &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(c.Raw), + }, + err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")), + } + }, + "ok": func(t *testing.T) test { + return test{ + fr: &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + if err := tc.fr.Validate(); err != nil { + if assert.NotNil(t, err) { + ae, ok := err.(*acme.Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.fr.csr.Raw, csr.Raw) + } + } + }) + } +} + +func TestHandlerGetOrder(t *testing.T) { + expiry := time.Now().UTC().Add(6 * time.Hour) + nbf := time.Now().UTC() + naf := time.Now().UTC().Add(24 * time.Hour) + o := acme.Order{ + ID: "orderID", + Expires: expiry.Format(time.RFC3339), + NotBefore: nbf.Format(time.RFC3339), + NotAfter: naf.Format(time.RFC3339), + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + Status: "pending", + Authorizations: []string{"foo", "bar"}, + } + + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("ordID", o.ID) + prov := newProv() + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s", + acme.URLSafeProvisionerName(prov), o.ID) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + auth: &mockAcmeAuthority{}, + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/getOrder-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.ServerInternalErr(errors.New("force")), + }, + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, o.ID) + return &o, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.OrderLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{o.ID}) + return url + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetOrder(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(o) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} + +func TestHandlerNewOrder(t *testing.T) { + expiry := time.Now().UTC().Add(6 * time.Hour) + nbf := time.Now().UTC().Add(5 * time.Hour) + naf := nbf.Add(17 * time.Hour) + o := acme.Order{ + ID: "orderID", + Expires: expiry.Format(time.RFC3339), + NotBefore: nbf.Format(time.RFC3339), + NotAfter: naf.Format(time.RFC3339), + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "bar.com"}, + }, + Status: "pending", + Authorizations: []string{"foo", "bar"}, + } + + prov := newProv() + url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", + acme.URLSafeProvisionerName(prov)) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/no-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/nil-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/unmarshal-payload-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")), + } + }, + "fail/malformed-payload-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{} + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")), + } + }, + "fail/NewOrder-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "bar.com"}, + }, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + assert.Equals(t, p, prov) + assert.Equals(t, ops.AccountID, acc.ID) + assert.Equals(t, ops.Identifiers, nor.Identifiers) + return nil, acme.MalformedErr(errors.New("force")) + }, + }, + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "bar.com"}, + }, + NotBefore: nbf, + NotAfter: naf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + assert.Equals(t, p, prov) + assert.Equals(t, ops.AccountID, acc.ID) + assert.Equals(t, ops.Identifiers, nor.Identifiers) + assert.Equals(t, ops.NotBefore, nbf) + assert.Equals(t, ops.NotAfter, naf) + return &o, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.OrderLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{o.ID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID) + }, + }, + ctx: ctx, + statusCode: 201, + } + }, + "ok/default-naf-nbf": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "bar.com"}, + }, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + auth: &mockAcmeAuthority{ + newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + assert.Equals(t, p, prov) + assert.Equals(t, ops.AccountID, acc.ID) + assert.Equals(t, ops.Identifiers, nor.Identifiers) + + assert.True(t, ops.NotBefore.IsZero()) + assert.True(t, ops.NotAfter.IsZero()) + return &o, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.OrderLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{o.ID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID) + }, + }, + ctx: ctx, + statusCode: 201, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.NewOrder(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(o) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Location"], + []string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} + +func TestHandlerFinalizeOrder(t *testing.T) { + expiry := time.Now().UTC().Add(6 * time.Hour) + nbf := time.Now().UTC().Add(5 * time.Hour) + naf := nbf.Add(17 * time.Hour) + o := acme.Order{ + ID: "orderID", + Expires: expiry.Format(time.RFC3339), + NotBefore: nbf.Format(time.RFC3339), + NotAfter: naf.Format(time.RFC3339), + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "bar.com"}, + }, + Status: "valid", + Authorizations: []string{"foo", "bar"}, + Certificate: "https://ca.smallstep.com/acme/certificate/certID", + } + _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") + assert.FatalError(t, err) + csr, ok := _csr.(*x509.CertificateRequest) + assert.Fatal(t, ok) + + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("ordID", o.ID) + prov := newProv() + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize", + acme.URLSafeProvisionerName(prov), o.ID) + + type test struct { + auth acme.Interface + ctx context.Context + statusCode int + problem *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.Background(), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, nil), + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, + "fail/no-account": func(t *testing.T) test { + return test{ + auth: &mockAcmeAuthority{}, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) + return test{ + auth: &mockAcmeAuthority{}, + ctx: ctx, + statusCode: 404, + problem: acme.AccountDoesNotExistErr(nil), + } + }, + "fail/no-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/nil-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + } + }, + "fail/unmarshal-payload-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")), + } + }, + "fail/malformed-payload-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + fr := &FinalizeRequest{} + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")), + } + }, + "fail/FinalizeOrder-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, o.ID) + assert.Equals(t, incsr.Raw, csr.Raw) + return nil, acme.MalformedErr(errors.New("force")) + }, + }, + ctx: ctx, + statusCode: 400, + problem: acme.MalformedErr(errors.New("force")), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + nor := &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { + assert.Equals(t, p, prov) + assert.Equals(t, accID, acc.ID) + assert.Equals(t, id, o.ID) + assert.Equals(t, incsr.Raw, csr.Raw) + return &o, nil + }, + getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + assert.Equals(t, typ, acme.OrderLink) + assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + assert.True(t, abs) + assert.Equals(t, in, []string{o.ID}) + return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", + acme.URLSafeProvisionerName(prov), o.ID) + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := New(tc.auth).(*Handler) + req := httptest.NewRequest("GET", url, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.FinalizeOrder(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + var ae acme.AError + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + prob := tc.problem.ToACME() + + assert.Equals(t, ae.Type, prob.Type) + assert.Equals(t, ae.Detail, prob.Detail) + assert.Equals(t, ae.Identifier, prob.Identifier) + assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + expB, err := json.Marshal(o) + assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Location"], + []string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", + acme.URLSafeProvisionerName(prov), o.ID)}) + assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) + } + }) + } +} diff --git a/acme/authority.go b/acme/authority.go new file mode 100644 index 00000000..a56d5ac2 --- /dev/null +++ b/acme/authority.go @@ -0,0 +1,263 @@ +package acme + +import ( + "crypto" + "crypto/x509" + "encoding/base64" + "net" + "net/http" + "net/url" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql" +) + +// Interface is the acme authority interface. +type Interface interface { + DeactivateAccount(provisioner.Interface, string) (*Account, error) + FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error) + GetAccount(provisioner.Interface, string) (*Account, error) + GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error) + GetAuthz(provisioner.Interface, string, string) (*Authz, error) + GetCertificate(string, string) ([]byte, error) + GetDirectory(provisioner.Interface) *Directory + GetLink(Link, string, bool, ...string) string + GetOrder(provisioner.Interface, string, string) (*Order, error) + GetOrdersByAccount(provisioner.Interface, string) ([]string, error) + LoadProvisionerByID(string) (provisioner.Interface, error) + NewAccount(provisioner.Interface, AccountOptions) (*Account, error) + NewNonce() (string, error) + NewOrder(provisioner.Interface, OrderOptions) (*Order, error) + UpdateAccount(provisioner.Interface, string, []string) (*Account, error) + UseNonce(string) error + ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error) +} + +// Authority is the layer that handles all ACME interactions. +type Authority struct { + db nosql.DB + dir *directory + signAuth SignAuthority +} + +// NewAuthority returns a new Authority that implements the ACME interface. +func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) *Authority { + return &Authority{ + db: db, dir: newDirectory(dns, prefix), signAuth: signAuth, + } +} + +// GetLink returns the requested link from the directory. +func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string { + return a.dir.getLink(typ, provID, abs, inputs...) +} + +// GetDirectory returns the ACME directory object. +func (a *Authority) GetDirectory(p provisioner.Interface) *Directory { + name := url.PathEscape(p.GetName()) + return &Directory{ + NewNonce: a.dir.getLink(NewNonceLink, name, true), + NewAccount: a.dir.getLink(NewAccountLink, name, true), + NewOrder: a.dir.getLink(NewOrderLink, name, true), + RevokeCert: a.dir.getLink(RevokeCertLink, name, true), + KeyChange: a.dir.getLink(KeyChangeLink, name, true), + } +} + +// LoadProvisionerByID calls out to the SignAuthority interface to load a +// provisioner by ID. +func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + return a.signAuth.LoadProvisionerByID(id) +} + +// NewNonce generates, stores, and returns a new ACME nonce. +func (a *Authority) NewNonce() (string, error) { + n, err := newNonce(a.db) + if err != nil { + return "", err + } + return n.ID, nil +} + +// UseNonce consumes the given nonce if it is valid, returns error otherwise. +func (a *Authority) UseNonce(nonce string) error { + return useNonce(a.db, nonce) +} + +// NewAccount creates, stores, and returns a new ACME account. +func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) { + acc, err := newAccount(a.db, ao) + if err != nil { + return nil, err + } + return acc.toACME(a.db, a.dir, p) +} + +// UpdateAccount updates an ACME account. +func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) { + acc, err := getAccountByID(a.db, id) + if err != nil { + return nil, ServerInternalErr(err) + } + if acc, err = acc.update(a.db, contact); err != nil { + return nil, err + } + return acc.toACME(a.db, a.dir, p) +} + +// GetAccount returns an ACME account. +func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) { + acc, err := getAccountByID(a.db, id) + if err != nil { + return nil, err + } + return acc.toACME(a.db, a.dir, p) +} + +// DeactivateAccount deactivates an ACME account. +func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) { + acc, err := getAccountByID(a.db, id) + if err != nil { + return nil, err + } + if acc, err = acc.deactivate(a.db); err != nil { + return nil, err + } + return acc.toACME(a.db, a.dir, p) +} + +func keyToID(jwk *jose.JSONWebKey) (string, error) { + kid, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint")) + } + return base64.RawURLEncoding.EncodeToString(kid), nil +} + +// GetAccountByKey returns the ACME associated with the jwk id. +func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) { + kid, err := keyToID(jwk) + if err != nil { + return nil, err + } + acc, err := getAccountByKeyID(a.db, kid) + if err != nil { + return nil, err + } + return acc.toACME(a.db, a.dir, p) +} + +// GetOrder returns an ACME order. +func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) { + o, err := getOrder(a.db, orderID) + if err != nil { + return nil, err + } + if accID != o.AccountID { + return nil, UnauthorizedErr(errors.New("account does not own order")) + } + if o, err = o.updateStatus(a.db); err != nil { + return nil, err + } + return o.toACME(a.db, a.dir, p) +} + +// GetOrdersByAccount returns the list of order urls owned by the account. +func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) { + oids, err := getOrderIDsByAccount(a.db, id) + if err != nil { + return nil, err + } + + var ret = []string{} + for _, oid := range oids { + o, err := getOrder(a.db, oid) + if err != nil { + return nil, ServerInternalErr(err) + } + if o.Status == StatusInvalid { + continue + } + ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID)) + } + return ret, nil +} + +// NewOrder generates, stores, and returns a new ACME order. +func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) { + order, err := newOrder(a.db, ops) + if err != nil { + return nil, Wrap(err, "error creating order") + } + return order.toACME(a.db, a.dir, p) +} + +// FinalizeOrder attempts to finalize an order and generate a new certificate. +func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { + o, err := getOrder(a.db, orderID) + if err != nil { + return nil, err + } + if accID != o.AccountID { + return nil, UnauthorizedErr(errors.New("account does not own order")) + } + o, err = o.finalize(a.db, csr, a.signAuth, p) + if err != nil { + return nil, Wrap(err, "error finalizing order") + } + return o.toACME(a.db, a.dir, p) +} + +// GetAuthz retrieves and attempts to update the status on an ACME authz +// before returning. +func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) { + authz, err := getAuthz(a.db, authzID) + if err != nil { + return nil, err + } + if accID != authz.getAccountID() { + return nil, UnauthorizedErr(errors.New("account does not own authz")) + } + authz, err = authz.updateStatus(a.db) + if err != nil { + return nil, Wrap(err, "error updating authz status") + } + return authz.toACME(a.db, a.dir, p) +} + +// ValidateChallenge attempts to validate the challenge. +func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { + ch, err := getChallenge(a.db, chID) + if err != nil { + return nil, err + } + if accID != ch.getAccountID() { + return nil, UnauthorizedErr(errors.New("account does not own challenge")) + } + client := http.Client{ + Timeout: time.Duration(30 * time.Second), + } + ch, err = ch.validate(a.db, jwk, validateOptions{ + httpGet: client.Get, + lookupTxt: net.LookupTXT, + }) + if err != nil { + return nil, Wrap(err, "error attempting challenge validation") + } + return ch.toACME(a.db, a.dir, p) +} + +// GetCertificate retrieves the Certificate by ID. +func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) { + cert, err := getCert(a.db, certID) + if err != nil { + return nil, err + } + if accID != cert.AccountID { + return nil, UnauthorizedErr(errors.New("account does not own certificate")) + } + return cert.toACME(a.db, a.dir) +} diff --git a/acme/authority_test.go b/acme/authority_test.go new file mode 100644 index 00000000..a5041fec --- /dev/null +++ b/acme/authority_test.go @@ -0,0 +1,1474 @@ +package acme + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql/database" +) + +func TestAuthorityGetLink(t *testing.T) { + auth := NewAuthority(nil, "ca.smallstep.com", "acme", nil) + provID := "acme-test-provisioner" + type test struct { + auth *Authority + typ Link + abs bool + inputs []string + res string + } + tests := map[string]func(t *testing.T) test{ + "ok/new-account/abs": func(t *testing.T) test { + return test{ + auth: auth, + typ: NewAccountLink, + abs: true, + res: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID), + } + }, + "ok/new-account/no-abs": func(t *testing.T) test { + return test{ + auth: auth, + typ: NewAccountLink, + abs: false, + res: fmt.Sprintf("/%s/new-account", provID), + } + }, + "ok/order/abs": func(t *testing.T) test { + return test{ + auth: auth, + typ: OrderLink, + abs: true, + inputs: []string{"foo"}, + res: fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/foo", provID), + } + }, + "ok/order/no-abs": func(t *testing.T) test { + return test{ + auth: auth, + typ: OrderLink, + abs: false, + inputs: []string{"foo"}, + res: fmt.Sprintf("/%s/order/foo", provID), + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + link := tc.auth.GetLink(tc.typ, provID, tc.abs, tc.inputs...) + assert.Equals(t, tc.res, link) + }) + } +} + +func TestAuthorityGetDirectory(t *testing.T) { + auth := NewAuthority(nil, "ca.smallstep.com", "acme", nil) + prov := newProv() + acmeDir := auth.GetDirectory(prov) + assert.Equals(t, acmeDir.NewNonce, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", URLSafeProvisionerName(prov))) + assert.Equals(t, acmeDir.NewAccount, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", URLSafeProvisionerName(prov))) + assert.Equals(t, acmeDir.NewOrder, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", URLSafeProvisionerName(prov))) + //assert.Equals(t, acmeDir.NewOrder, "httsp://ca.smallstep.com/acme/new-authz") + assert.Equals(t, acmeDir.RevokeCert, fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", URLSafeProvisionerName(prov))) + assert.Equals(t, acmeDir.KeyChange, fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", URLSafeProvisionerName(prov))) +} + +func TestAuthorityNewNonce(t *testing.T) { + type test struct { + auth *Authority + res *string + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/newNonce-error": func(t *testing.T) test { + auth := NewAuthority(&db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + res: nil, + err: ServerInternalErr(errors.New("error storing nonce: force")), + } + }, + "ok": func(t *testing.T) test { + var _res string + res := &_res + auth := NewAuthority(&db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + *res = string(key) + return nil, true, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + res: res, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if nonce, err := tc.auth.NewNonce(); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, nonce, *tc.res) + } + } + }) + } +} + +func TestAuthorityUseNonce(t *testing.T) { + type test struct { + auth *Authority + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/newNonce-error": func(t *testing.T) test { + auth := NewAuthority(&db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + return errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + err: ServerInternalErr(errors.New("error deleting nonce foo: force")), + } + }, + "ok": func(t *testing.T) test { + auth := NewAuthority(&db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + return nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.auth.UseNonce("foo"); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestAuthorityNewAccount(t *testing.T) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ops := AccountOptions{ + Key: jwk, Contact: []string{"foo", "bar"}, + } + prov := newProv() + type test struct { + auth *Authority + ops AccountOptions + err *Error + acc **Account + } + tests := map[string]func(t *testing.T) test{ + "fail/newAccount-error": func(t *testing.T) test { + auth := NewAuthority(&db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + ops: ops, + err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), + } + }, + "ok": func(t *testing.T) test { + var ( + _acmeacc = &Account{} + acmeacc = &_acmeacc + count = 0 + dir = newDirectory("ca.smallstep.com", "acme") + ) + auth := NewAuthority(&db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 1 { + var acc *account + assert.FatalError(t, json.Unmarshal(newval, &acc)) + *acmeacc, err = acc.toACME(nil, dir, prov) + return nil, true, nil + } + count++ + return nil, true, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + ops: ops, + acc: acmeacc, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeAcc, err := tc.auth.NewAccount(prov, tc.ops); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeAcc) + assert.FatalError(t, err) + expb, err := json.Marshal(*tc.acc) + assert.FatalError(t, err) + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityGetAccount(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id string + err *Error + acc *account + } + tests := map[string]func(t *testing.T) test{ + "fail/getAccount-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: acc.ID, + acc: acc, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeAcc, err := tc.auth.GetAccount(prov, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeAcc) + assert.FatalError(t, err) + + acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityGetAccountByKey(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + jwk *jose.JSONWebKey + err *Error + acc *account + } + tests := map[string]func(t *testing.T) test{ + "fail/generate-thumbprint-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + auth := NewAuthority(nil, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + jwk: jwk, + err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), + } + }, + "fail/getAccount-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + kid, err := keyToID(jwk) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + jwk: jwk, + err: ServerInternalErr(errors.New("error loading key-account index: force")), + } + }, + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + count := 0 + kid, err := keyToID(acc.Key) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch { + case count == 0: + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, key, []byte(kid)) + ret = []byte(acc.ID) + case count == 1: + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + ret = b + } + count++ + return ret, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + jwk: acc.Key, + acc: acc, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeAcc, err := tc.auth.GetAccountByKey(prov, tc.jwk); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeAcc) + assert.FatalError(t, err) + + acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityGetOrder(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id, accID string + err *Error + o *order + } + tests := map[string]func(t *testing.T) test{ + "fail/getOrder-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.New("error loading order foo: force")), + } + }, + "fail/order-not-owned-by-account": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: o.ID, + accID: "foo", + err: UnauthorizedErr(errors.New("account does not own order")), + } + }, + "fail/updateStatus-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + i := 0 + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch { + case i == 0: + i++ + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + default: + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(o.Authorizations[0])) + return nil, ServerInternalErr(errors.New("force")) + } + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: o.ID, + accID: o.AccountID, + err: ServerInternalErr(errors.Errorf("error loading authz %s: force", o.Authorizations[0])), + } + }, + "ok": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = "valid" + b, err := json.Marshal(o) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: o.ID, + accID: o.AccountID, + o: o, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeO, err := tc.auth.GetOrder(prov, tc.accID, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeO) + assert.FatalError(t, err) + + acmeExp, err := tc.o.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityGetCertificate(t *testing.T) { + type test struct { + auth *Authority + id, accID string + err *Error + cert *certificate + } + tests := map[string]func(t *testing.T) test{ + "fail/getCertificate-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.New("error loading certificate: force")), + } + }, + "fail/certificate-not-owned-by-account": func(t *testing.T) test { + cert, err := newcert() + assert.FatalError(t, err) + b, err := json.Marshal(cert) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: cert.ID, + accID: "foo", + err: UnauthorizedErr(errors.New("account does not own certificate")), + } + }, + "ok": func(t *testing.T) test { + cert, err := newcert() + assert.FatalError(t, err) + b, err := json.Marshal(cert) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: cert.ID, + accID: cert.AccountID, + cert: cert, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeCert, err := tc.auth.GetCertificate(tc.accID, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeCert) + assert.FatalError(t, err) + + acmeExp, err := tc.cert.toACME(nil, tc.auth.dir) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityGetAuthz(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id, accID string + err *Error + acmeAz *Authz + } + tests := map[string]func(t *testing.T) test{ + "fail/getAuthz-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.Errorf("error loading authz %s: force", id)), + } + }, + "fail/authz-not-owned-by-account": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(az.getID())) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: az.getID(), + accID: "foo", + err: UnauthorizedErr(errors.New("account does not own authz")), + } + }, + "fail/update-status-error": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + count := 0 + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(az.getID())) + ret = b + case 1: + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(az.getChallenges()[0])) + return nil, errors.New("force") + } + count++ + return ret, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: az.getID(), + accID: az.getAccountID(), + err: ServerInternalErr(errors.New("error updating authz status: error loading challenge")), + } + }, + "ok": func(t *testing.T) test { + var ch1B, ch2B = &[]byte{}, &[]byte{} + count := 0 + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + switch count { + case 0: + *ch1B = newval + case 1: + *ch2B = newval + } + count++ + return nil, true, nil + }, + } + az, err := newAuthz(mockdb, "1234", Identifier{ + Type: "dns", Value: "acme.example.com", + }) + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Status = StatusValid + b, err := json.Marshal(az) + assert.FatalError(t, err) + + ch1, err := unmarshalChallenge(*ch1B) + assert.FatalError(t, err) + ch2, err := unmarshalChallenge(*ch2B) + assert.FatalError(t, err) + count = 0 + mockdb = &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch1.getID())) + ret = *ch1B + case 1: + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch2.getID())) + ret = *ch2B + } + count++ + return ret, nil + }, + } + acmeAz, err := az.toACME(mockdb, newDirectory("ca.smallstep.com", "acme"), prov) + assert.FatalError(t, err) + + count = 0 + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(az.getID())) + ret = b + case 1: + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch1.getID())) + ret = *ch1B + case 2: + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch2.getID())) + ret = *ch2B + } + count++ + return ret, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: az.getID(), + accID: az.getAccountID(), + acmeAz: acmeAz, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeAz, err := tc.auth.GetAuthz(prov, tc.accID, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeAz) + assert.FatalError(t, err) + + expb, err := json.Marshal(tc.acmeAz) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityNewOrder(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + ops OrderOptions + err *Error + o **Order + } + tests := map[string]func(t *testing.T) test{ + "fail/newOrder-error": func(t *testing.T) test { + auth := NewAuthority(&db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + ops: defaultOrderOps(), + err: ServerInternalErr(errors.New("error creating order: error creating http challenge: error saving acme challenge: force")), + } + }, + "ok": func(t *testing.T) test { + var ( + _acmeO = &Order{} + acmeO = &_acmeO + count = 0 + dir = newDirectory("ca.smallstep.com", "acme") + err error + _accID string + accID = &_accID + ) + auth := NewAuthority(&db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + switch count { + case 0: + assert.Equals(t, bucket, challengeTable) + case 1: + assert.Equals(t, bucket, challengeTable) + case 2: + assert.Equals(t, bucket, authzTable) + case 3: + assert.Equals(t, bucket, challengeTable) + case 4: + assert.Equals(t, bucket, challengeTable) + case 5: + assert.Equals(t, bucket, authzTable) + case 6: + assert.Equals(t, bucket, orderTable) + var o order + assert.FatalError(t, json.Unmarshal(newval, &o)) + *acmeO, err = o.toACME(nil, dir, prov) + assert.FatalError(t, err) + *accID = o.AccountID + case 7: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, string(key), *accID) + } + count++ + return nil, true, nil + }, + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + ops: defaultOrderOps(), + o: acmeO, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeO, err := tc.auth.NewOrder(prov, tc.ops); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeO) + assert.FatalError(t, err) + expb, err := json.Marshal(*tc.o) + assert.FatalError(t, err) + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityGetOrdersByAccount(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id string + err *Error + res []string + } + tests := map[string]func(t *testing.T) test{ + "fail/getOrderIDsByAccount-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), + } + }, + "fail/getOrder-error": func(t *testing.T) test { + var ( + id = "zap" + oids = []string{"foo", "bar"} + count = 0 + err error + ) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(id)) + ret, err = json.Marshal(oids) + assert.FatalError(t, err) + case 1: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(oids[0])) + return nil, errors.New("force") + } + count++ + return ret, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.New("error loading order foo: force")), + } + }, + "ok": func(t *testing.T) test { + var ( + id = "zap" + count = 0 + err error + ) + foo, err := newO() + bar, err := newO() + baz, err := newO() + bar.Status = StatusInvalid + + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(id)) + ret, err = json.Marshal([]string{foo.ID, bar.ID, baz.ID}) + assert.FatalError(t, err) + case 1: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(foo.ID)) + ret, err = json.Marshal(foo) + assert.FatalError(t, err) + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(bar.ID)) + ret, err = json.Marshal(bar) + assert.FatalError(t, err) + case 3: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(baz.ID)) + ret, err = json.Marshal(baz) + assert.FatalError(t, err) + } + count++ + return ret, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + res: []string{ + fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", URLSafeProvisionerName(prov), foo.ID), + fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", URLSafeProvisionerName(prov), baz.ID), + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if orderLinks, err := tc.auth.GetOrdersByAccount(prov, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res, orderLinks) + } + } + }) + } +} + +func TestAuthorityFinalizeOrder(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id, accID string + err *Error + o *order + } + tests := map[string]func(t *testing.T) test{ + "fail/getOrder-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.New("error loading order foo: force")), + } + }, + "fail/order-not-owned-by-account": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: o.ID, + accID: "foo", + err: UnauthorizedErr(errors.New("account does not own order")), + } + }, + "fail/finalize-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Expires = time.Now().Add(-time.Minute) + b, err := json.Marshal(o) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: o.ID, + accID: o.AccountID, + err: ServerInternalErr(errors.New("error finalizing order: error storing order: force")), + } + }, + "ok": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusValid + o.Certificate = "certID" + b, err := json.Marshal(o) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: o.ID, + accID: o.AccountID, + o: o, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeO, err := tc.auth.FinalizeOrder(prov, tc.accID, tc.id, nil); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeO) + assert.FatalError(t, err) + + acmeExp, err := tc.o.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityValidateChallenge(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id, accID string + err *Error + ch challenge + } + tests := map[string]func(t *testing.T) test{ + "fail/getChallenge-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", id)), + } + }, + "fail/challenge-not-owned-by-account": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + b, err := json.Marshal(ch) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: ch.getID(), + accID: "foo", + err: UnauthorizedErr(errors.New("account does not own challenge")), + } + }, + "fail/validate-error": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + b, err := json.Marshal(ch) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: ch.getID(), + accID: ch.getAccountID(), + err: ServerInternalErr(errors.New("error attempting challenge validation: error saving acme challenge: force")), + } + }, + "ok": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + _ch, ok := ch.(*http01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusValid + _ch.baseChallenge.Validated = clock.Now() + b, err := json.Marshal(ch) + assert.FatalError(t, err) + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + return b, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: ch.getID(), + accID: ch.getAccountID(), + ch: ch, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeCh, err := tc.auth.ValidateChallenge(prov, tc.accID, tc.id, nil); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeCh) + assert.FatalError(t, err) + + acmeExp, err := tc.ch.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityUpdateAccount(t *testing.T) { + contact := []string{"baz", "zap"} + prov := newProv() + type test struct { + auth *Authority + id string + contact []string + acc *account + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/getAccount-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + contact: contact, + err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), + } + }, + "fail/update-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: acc.ID, + contact: contact, + err: ServerInternalErr(errors.New("error storing account: force")), + } + }, + + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + + _acc := *acc + clone := &_acc + clone.Contact = contact + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + return nil, true, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: acc.ID, + contact: contact, + acc: clone, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeAcc, err := tc.auth.UpdateAccount(prov, tc.id, tc.contact); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeAcc) + assert.FatalError(t, err) + + acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} + +func TestAuthorityDeactivateAccount(t *testing.T) { + prov := newProv() + type test struct { + auth *Authority + id string + acc *account + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/getAccount-error": func(t *testing.T) test { + id := "foo" + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(id)) + return nil, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: id, + err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), + } + }, + "fail/deactivate-error": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: acc.ID, + err: ServerInternalErr(errors.New("error storing account: force")), + } + }, + + "ok": func(t *testing.T) test { + acc, err := newAcc() + assert.FatalError(t, err) + b, err := json.Marshal(acc) + assert.FatalError(t, err) + + _acc := *acc + clone := &_acc + clone.Status = StatusDeactivated + clone.Deactivated = clock.Now() + auth := NewAuthority(&db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, key, []byte(acc.ID)) + return nil, true, nil + }, + }, "ca.smallstep.com", "acme", nil) + return test{ + auth: auth, + id: acc.ID, + acc: clone, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if acmeAcc, err := tc.auth.DeactivateAccount(prov, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + gotb, err := json.Marshal(acmeAcc) + assert.FatalError(t, err) + + acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + assert.FatalError(t, err) + expb, err := json.Marshal(acmeExp) + assert.FatalError(t, err) + + assert.Equals(t, expb, gotb) + } + } + }) + } +} diff --git a/acme/authz.go b/acme/authz.go new file mode 100644 index 00000000..132977e1 --- /dev/null +++ b/acme/authz.go @@ -0,0 +1,344 @@ +package acme + +import ( + "encoding/json" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/nosql" +) + +var defaultExpiryDuration = time.Hour * 24 + +// Authz is a subset of the Authz type containing only those attributes +// required for responses in the ACME protocol. +type Authz struct { + Identifier Identifier `json:"identifier"` + Status string `json:"status"` + Expires string `json:"expires"` + Challenges []*Challenge `json:"challenges"` + Wildcard bool `json:"wildcard"` + ID string `json:"-"` +} + +// ToLog enables response logging. +func (a *Authz) ToLog() (interface{}, error) { + b, err := json.Marshal(a) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) + } + return string(b), nil +} + +// GetID returns the Authz ID. +func (a *Authz) GetID() string { + return a.ID +} + +// authz is the interface that the various authz types must implement. +type authz interface { + save(nosql.DB, authz) error + clone() *baseAuthz + getID() string + getAccountID() string + getType() string + getIdentifier() Identifier + getStatus() string + getExpiry() time.Time + getWildcard() bool + getChallenges() []string + getCreated() time.Time + updateStatus(db nosql.DB) (authz, error) + toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error) +} + +// baseAuthz is the base authz type that others build from. +type baseAuthz struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier Identifier `json:"identifier"` + Status string `json:"status"` + Expires time.Time `json:"expires"` + Challenges []string `json:"challenges"` + Wildcard bool `json:"wildcard"` + Created time.Time `json:"created"` + Error *Error `json:"error"` +} + +func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) { + id, err := randID() + if err != nil { + return nil, err + } + + now := clock.Now() + ba := &baseAuthz{ + ID: id, + AccountID: accID, + Status: StatusPending, + Created: now, + Expires: now.Add(defaultExpiryDuration), + Identifier: identifier, + } + + if strings.HasPrefix(identifier.Value, "*.") { + ba.Wildcard = true + ba.Identifier = Identifier{ + Value: strings.TrimPrefix(identifier.Value, "*."), + Type: identifier.Type, + } + } + + return ba, nil +} + +// getID returns the ID of the authz. +func (ba *baseAuthz) getID() string { + return ba.ID +} + +// getAccountID returns the Account ID that created the authz. +func (ba *baseAuthz) getAccountID() string { + return ba.AccountID +} + +// getType returns the type of the authz. +func (ba *baseAuthz) getType() string { + return ba.Identifier.Type +} + +// getIdentifier returns the identifier for the authz. +func (ba *baseAuthz) getIdentifier() Identifier { + return ba.Identifier +} + +// getStatus returns the status of the authz. +func (ba *baseAuthz) getStatus() string { + return ba.Status +} + +// getWildcard returns true if the authz identifier has a '*', false otherwise. +func (ba *baseAuthz) getWildcard() bool { + return ba.Wildcard +} + +// getChallenges returns the authz challenge IDs. +func (ba *baseAuthz) getChallenges() []string { + return ba.Challenges +} + +// getExpiry returns the expiration time of the authz. +func (ba *baseAuthz) getExpiry() time.Time { + return ba.Expires +} + +// getCreated returns the created time of the authz. +func (ba *baseAuthz) getCreated() time.Time { + return ba.Created +} + +// toACME converts the internal Authz type into the public acmeAuthz type for +// presentation in the ACME protocol. +func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) { + var chs = make([]*Challenge, len(ba.Challenges)) + for i, chID := range ba.Challenges { + ch, err := getChallenge(db, chID) + if err != nil { + return nil, err + } + chs[i], err = ch.toACME(db, dir, p) + if err != nil { + return nil, err + } + } + return &Authz{ + Identifier: ba.Identifier, + Status: ba.getStatus(), + Challenges: chs, + Wildcard: ba.getWildcard(), + Expires: ba.Expires.Format(time.RFC3339), + ID: ba.ID, + }, nil +} + +func (ba *baseAuthz) save(db nosql.DB, old authz) error { + var ( + err error + oldB, newB []byte + ) + if old == nil { + oldB = nil + } else { + if oldB, err = json.Marshal(old); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old authz")) + } + } + if newB, err = json.Marshal(ba); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling new authz")) + } + _, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error storing authz")) + case !swapped: + return ServerInternalErr(errors.Errorf("error storing authz; " + + "value has changed since last read")) + default: + return nil + } +} + +func (ba *baseAuthz) clone() *baseAuthz { + u := *ba + return &u +} + +func (ba *baseAuthz) storeAndReturnError(db nosql.DB, err *Error) error { + clone := ba.clone() + clone.Error = err + clone.save(db, ba) + return err +} + +func (ba *baseAuthz) parent() authz { + return &dnsAuthz{ba} +} + +// updateStatus attempts to update the status on a baseAuthz and stores the +// updating object if necessary. +func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) { + newAuthz := ba.clone() + + now := time.Now().UTC() + switch ba.Status { + case StatusInvalid: + return ba.parent(), nil + case StatusValid: + return ba.parent(), nil + case StatusPending: + // check expiry + if now.After(ba.Expires) { + newAuthz.Status = StatusInvalid + newAuthz.Error = MalformedErr(errors.New("authz has expired")) + break + } + + var isValid = false + for _, chID := range ba.Challenges { + ch, err := getChallenge(db, chID) + if err != nil { + return ba, err + } + if ch.getStatus() == StatusValid { + isValid = true + break + } + } + + if !isValid { + return ba.parent(), nil + } + newAuthz.Status = StatusValid + newAuthz.Error = nil + default: + return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) + } + + if err := newAuthz.save(db, ba); err != nil { + return ba, err + } + return newAuthz.parent(), nil +} + +// unmarshalAuthz unmarshals an authz type into the correct sub-type. +func unmarshalAuthz(data []byte) (authz, error) { + var getType struct { + Identifier Identifier `json:"identifier"` + } + if err := json.Unmarshal(data, &getType); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type")) + } + + switch getType.Identifier.Type { + case "dns": + var ba baseAuthz + if err := json.Unmarshal(data, &ba); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz")) + } + return &dnsAuthz{&ba}, nil + default: + return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s", + getType.Identifier.Type)) + } +} + +// dnsAuthz represents a dns acme authorization. +type dnsAuthz struct { + *baseAuthz +} + +// newAuthz returns a new acme authorization object based on the identifier +// type. +func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) { + switch identifier.Type { + case "dns": + a, err = newDNSAuthz(db, accID, identifier) + default: + err = MalformedErr(errors.Errorf("unexpected authz type %s", + identifier.Type)) + } + return +} + +// newDNSAuthz returns a new dns acme authorization object. +func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) { + ba, err := newBaseAuthz(accID, identifier) + if err != nil { + return nil, err + } + + ba.Challenges = []string{} + if !ba.Wildcard { + // http challenges are only permitted if the DNS is not a wildcard dns. + ch1, err := newHTTP01Challenge(db, ChallengeOptions{ + AccountID: accID, + AuthzID: ba.ID, + Identifier: ba.Identifier}) + if err != nil { + return nil, Wrap(err, "error creating http challenge") + } + ba.Challenges = append(ba.Challenges, ch1.getID()) + } + ch2, err := newDNS01Challenge(db, ChallengeOptions{ + AccountID: accID, + AuthzID: ba.ID, + Identifier: identifier}) + if err != nil { + return nil, Wrap(err, "error creating dns challenge") + } + ba.Challenges = append(ba.Challenges, ch2.getID()) + + da := &dnsAuthz{ba} + if err := da.save(db, nil); err != nil { + return nil, err + } + + return da, nil +} + +// getAuthz retrieves and unmarshals an ACME authz type from the database. +func getAuthz(db nosql.DB, id string) (authz, error) { + b, err := db.Get(authzTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id)) + } + az, err := unmarshalAuthz(b) + if err != nil { + return nil, err + } + return az, nil +} diff --git a/acme/authz_test.go b/acme/authz_test.go new file mode 100644 index 00000000..96213e4f --- /dev/null +++ b/acme/authz_test.go @@ -0,0 +1,809 @@ +package acme + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +func newAz() (authz, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newAuthz(mockdb, "1234", Identifier{ + Type: "dns", Value: "acme.example.com", + }) +} + +func TestGetAuthz(t *testing.T) { + type test struct { + id string + db nosql.DB + az authz + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + return test{ + az: az, + id: az.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())), + } + }, + "fail/db-error": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + return test{ + az: az, + id: az.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Identifier.Type = "foo" + b, err := json.Marshal(az) + assert.FatalError(t, err) + return test{ + az: az, + id: az.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(az.getID())) + return b, nil + }, + }, + err: ServerInternalErr(errors.New("unexpected authz type foo")), + } + }, + "ok": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + return test{ + az: az, + id: az.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(az.getID())) + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if az, err := getAuthz(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.az.getID(), az.getID()) + assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) + assert.Equals(t, tc.az.getStatus(), az.getStatus()) + assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier()) + assert.Equals(t, tc.az.getCreated(), az.getCreated()) + assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) + assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) + } + } + }) + } +} + +func TestAuthzClone(t *testing.T) { + az, err := newAz() + assert.FatalError(t, err) + + clone := az.clone() + + assert.Equals(t, clone.getID(), az.getID()) + assert.Equals(t, clone.getAccountID(), az.getAccountID()) + assert.Equals(t, clone.getStatus(), az.getStatus()) + assert.Equals(t, clone.getIdentifier(), az.getIdentifier()) + assert.Equals(t, clone.getExpiry(), az.getExpiry()) + assert.Equals(t, clone.getCreated(), az.getCreated()) + assert.Equals(t, clone.getChallenges(), az.getChallenges()) + + clone.Status = StatusValid + + assert.NotEquals(t, clone.getStatus(), az.getStatus()) +} + +func TestNewAuthz(t *testing.T) { + iden := Identifier{ + Type: "dns", Value: "acme.example.com", + } + accID := "1234" + type test struct { + iden Identifier + db nosql.DB + err *Error + resChs *([]string) + } + tests := map[string]func(t *testing.T) test{ + "fail/unexpected-type": func(t *testing.T) test { + return test{ + iden: Identifier{Type: "foo", Value: "acme.example.com"}, + err: MalformedErr(errors.New("unexpected authz type foo")), + } + }, + "fail/new-http-chall-error": func(t *testing.T) test { + return test{ + iden: iden, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")), + } + }, + "fail/new-dns-chall-error": func(t *testing.T) test { + count := 0 + return test{ + iden: iden, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 1 { + return nil, false, errors.New("force") + } + count++ + return nil, true, nil + }, + }, + err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")), + } + }, + "fail/save-authz-error": func(t *testing.T) test { + count := 0 + return test{ + iden: iden, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 2 { + return nil, false, errors.New("force") + } + count++ + return nil, true, nil + }, + }, + err: ServerInternalErr(errors.New("error storing authz: force")), + } + }, + "ok": func(t *testing.T) test { + chs := &([]string{}) + count := 0 + return test{ + iden: iden, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 2 { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, old, nil) + + az, err := unmarshalAuthz(newval) + assert.FatalError(t, err) + + assert.Equals(t, az.getID(), string(key)) + assert.Equals(t, az.getAccountID(), accID) + assert.Equals(t, az.getStatus(), StatusPending) + assert.Equals(t, az.getIdentifier(), iden) + assert.Equals(t, az.getWildcard(), false) + + *chs = az.getChallenges() + + assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + expiry := az.getCreated().Add(defaultExpiryDuration) + assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) + assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) + } + count++ + return nil, true, nil + }, + }, + resChs: chs, + } + }, + "ok/wildcard": func(t *testing.T) test { + chs := &([]string{}) + count := 0 + _iden := Identifier{Type: "dns", Value: "*.acme.example.com"} + return test{ + iden: _iden, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 1 { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, old, nil) + + az, err := unmarshalAuthz(newval) + assert.FatalError(t, err) + + assert.Equals(t, az.getID(), string(key)) + assert.Equals(t, az.getAccountID(), accID) + assert.Equals(t, az.getStatus(), StatusPending) + assert.Equals(t, az.getIdentifier(), iden) + assert.Equals(t, az.getWildcard(), true) + + *chs = az.getChallenges() + // Verify that we only have 1 challenge instead of 2. + assert.True(t, len(*chs) == 1) + + assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + expiry := az.getCreated().Add(defaultExpiryDuration) + assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) + assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) + } + count++ + return nil, true, nil + }, + }, + resChs: chs, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + az, err := newAuthz(tc.db, accID, tc.iden) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, az.getAccountID(), accID) + assert.Equals(t, az.getType(), "dns") + assert.Equals(t, az.getStatus(), StatusPending) + + assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + expiry := az.getCreated().Add(defaultExpiryDuration) + assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) + assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) + + assert.Equals(t, az.getChallenges(), *(tc.resChs)) + + if strings.HasPrefix(tc.iden.Value, "*.") { + assert.True(t, az.getWildcard()) + assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*.")) + } else { + assert.False(t, az.getWildcard()) + assert.Equals(t, az.getIdentifier().Value, tc.iden.Value) + } + + assert.True(t, az.getID() != "") + } + } + }) + } +} + +func TestAuthzToACME(t *testing.T) { + dir := newDirectory("ca.smallstep.com", "acme") + + var ( + ch1, ch2 challenge + ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) + err error + ) + + count := 0 + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 0 { + *ch1Bytes = newval + ch1, err = unmarshalChallenge(newval) + assert.FatalError(t, err) + } else if count == 1 { + *ch2Bytes = newval + ch2, err = unmarshalChallenge(newval) + assert.FatalError(t, err) + } + count++ + return []byte("foo"), true, nil + }, + } + iden := Identifier{ + Type: "dns", Value: "acme.example.com", + } + az, err := newAuthz(mockdb, "1234", iden) + assert.FatalError(t, err) + prov := newProv() + + type test struct { + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/getChallenge1-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading challenge")), + } + }, + "fail/getChallenge2-error": func(t *testing.T) test { + count := 0 + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if count == 1 { + return nil, errors.New("force") + } + count++ + return *ch1Bytes, nil + }, + }, + err: ServerInternalErr(errors.New("error loading challenge")), + } + }, + "ok": func(t *testing.T) test { + count := 0 + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if count == 0 { + count++ + return *ch1Bytes, nil + } + return *ch2Bytes, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + acmeAz, err := az.toACME(tc.db, dir, prov) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acmeAz.ID, az.getID()) + assert.Equals(t, acmeAz.Identifier, iden) + assert.Equals(t, acmeAz.Status, StatusPending) + + acmeCh1, err := ch1.toACME(nil, dir, prov) + assert.FatalError(t, err) + acmeCh2, err := ch2.toACME(nil, dir, prov) + assert.FatalError(t, err) + + assert.Equals(t, acmeAz.Challenges[0], acmeCh1) + assert.Equals(t, acmeAz.Challenges[1], acmeCh2) + + expiry, err := time.Parse(time.RFC3339, acmeAz.Expires) + assert.FatalError(t, err) + assert.Equals(t, expiry.String(), az.getExpiry().String()) + } + } + }) + } +} + +func TestAuthzSave(t *testing.T) { + type test struct { + az, old authz + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + return test{ + az: az, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing authz: force")), + } + }, + "fail/old-nil/swap-false": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + return test{ + az: az, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil + }, + }, + err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")), + } + }, + "ok/old-nil": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + return test{ + az: az, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, authzTable) + assert.Equals(t, []byte(az.getID()), key) + return nil, true, nil + }, + }, + } + }, + "ok/old-not-nil": func(t *testing.T) test { + oldAz, err := newAz() + assert.FatalError(t, err) + az, err := newAz() + assert.FatalError(t, err) + + oldb, err := json.Marshal(oldAz) + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + return test{ + az: az, + old: oldAz, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, oldb) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, authzTable) + assert.Equals(t, []byte(az.getID()), key) + return []byte("foo"), true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.az.save(tc.db, tc.old); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestAuthzUnmarshal(t *testing.T) { + type test struct { + az authz + azb []byte + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/nil": func(t *testing.T) test { + return test{ + azb: nil, + err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")), + } + }, + "fail/unexpected-type": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Identifier.Type = "foo" + b, err := json.Marshal(az) + assert.FatalError(t, err) + return test{ + azb: b, + err: ServerInternalErr(errors.New("unexpected authz type foo")), + } + }, + "ok/dns": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + return test{ + az: az, + azb: b, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if az, err := unmarshalAuthz(tc.azb); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.az.getID(), az.getID()) + assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) + assert.Equals(t, tc.az.getStatus(), az.getStatus()) + assert.Equals(t, tc.az.getCreated(), az.getCreated()) + assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) + assert.Equals(t, tc.az.getWildcard(), az.getWildcard()) + assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) + } + } + }) + } +} + +func TestAuthzUpdateStatus(t *testing.T) { + type test struct { + az, res authz + err *Error + db nosql.DB + } + tests := map[string]func(t *testing.T) test{ + "fail/already-invalid": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Status = StatusInvalid + return test{ + az: az, + res: az, + } + }, + "fail/already-valid": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Status = StatusValid + return test{ + az: az, + res: az, + } + }, + "fail/unexpected-status": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Status = StatusReady + return test{ + az: az, + res: az, + err: ServerInternalErr(errors.New("unrecognized authz status: ready")), + } + }, + "fail/save-error": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) + return test{ + az: az, + res: az, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing authz: force")), + } + }, + "ok/expired": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) + + clone := az.clone() + clone.Error = MalformedErr(errors.New("authz has expired")) + clone.Status = StatusInvalid + return test{ + az: az, + res: clone.parent(), + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + }, + } + }, + "fail/get-challenge-error": func(t *testing.T) test { + az, err := newAz() + assert.FatalError(t, err) + + return test{ + az: az, + res: az, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading challenge")), + } + }, + "ok/valid": func(t *testing.T) test { + var ( + ch2 challenge + ch1Bytes = &([]byte{}) + err error + ) + + count := 0 + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 0 { + *ch1Bytes = newval + } else if count == 1 { + ch2, err = unmarshalChallenge(newval) + assert.FatalError(t, err) + } + count++ + return nil, true, nil + }, + } + iden := Identifier{ + Type: "dns", Value: "acme.example.com", + } + az, err := newAuthz(mockdb, "1234", iden) + assert.FatalError(t, err) + _az, ok := az.(*dnsAuthz) + assert.Fatal(t, ok) + _az.baseAuthz.Error = MalformedErr(nil) + + _ch, ok := ch2.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusValid + chb, err := json.Marshal(ch2) + + clone := az.clone() + clone.Status = StatusValid + clone.Error = nil + + count = 0 + return test{ + az: az, + res: clone.parent(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if count == 0 { + count++ + return *ch1Bytes, nil + } + count++ + return chb, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + }, + } + }, + "ok/still-pending": func(t *testing.T) test { + var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) + + count := 0 + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 0 { + *ch1Bytes = newval + } else if count == 1 { + *ch2Bytes = newval + } + count++ + return nil, true, nil + }, + } + iden := Identifier{ + Type: "dns", Value: "acme.example.com", + } + az, err := newAuthz(mockdb, "1234", iden) + assert.FatalError(t, err) + + count = 0 + return test{ + az: az, + res: az, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if count == 0 { + count++ + return *ch1Bytes, nil + } + count++ + return *ch2Bytes, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + az, err := tc.az.updateStatus(tc.db) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + expB, err := json.Marshal(tc.res) + assert.FatalError(t, err) + b, err := json.Marshal(az) + assert.FatalError(t, err) + assert.Equals(t, expB, b) + } + } + }) + } +} diff --git a/acme/certificate.go b/acme/certificate.go new file mode 100644 index 00000000..6a31c880 --- /dev/null +++ b/acme/certificate.go @@ -0,0 +1,89 @@ +package acme + +import ( + "crypto/x509" + "encoding/json" + "encoding/pem" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql" +) + +type certificate struct { + ID string `json:"id"` + Created time.Time `json:"created"` + AccountID string `json:"accountID"` + OrderID string `json:"orderID"` + Leaf []byte `json:"leaf"` + Intermediates []byte `json:"intermediates"` +} + +// CertOptions options with which to create and store a cert object. +type CertOptions struct { + AccountID string + OrderID string + Leaf *x509.Certificate + Intermediates []*x509.Certificate +} + +func newCert(db nosql.DB, ops CertOptions) (*certificate, error) { + id, err := randID() + if err != nil { + return nil, err + } + + leaf := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: ops.Leaf.Raw, + }) + var intermediates []byte + for _, cert := range ops.Intermediates { + intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })...) + } + + cert := &certificate{ + ID: id, + AccountID: ops.AccountID, + OrderID: ops.OrderID, + Leaf: leaf, + Intermediates: intermediates, + Created: time.Now().UTC(), + } + certB, err := json.Marshal(cert) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate")) + } + + _, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB) + switch { + case err != nil: + return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate")) + case !swapped: + return nil, ServerInternalErr(errors.New("error storing certificate; " + + "value has changed since last read")) + default: + return cert, nil + } +} + +func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) { + return append(c.Leaf, c.Intermediates...), nil +} + +func getCert(db nosql.DB, id string) (*certificate, error) { + b, err := db.Get(certTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate")) + } + var cert certificate + if err := json.Unmarshal(b, &cert); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate")) + } + return &cert, nil +} diff --git a/acme/certificate_test.go b/acme/certificate_test.go new file mode 100644 index 00000000..e99eb5af --- /dev/null +++ b/acme/certificate_test.go @@ -0,0 +1,253 @@ +package acme + +import ( + "crypto/x509" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +func defaultCertOps() (*CertOptions, error) { + crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt") + if err != nil { + return nil, err + } + inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt") + if err != nil { + return nil, err + } + root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt") + if err != nil { + return nil, err + } + return &CertOptions{ + AccountID: "accID", + OrderID: "ordID", + Leaf: crt, + Intermediates: []*x509.Certificate{inter, root}, + }, nil +} + +func newcert() (*certificate, error) { + ops, err := defaultCertOps() + if err != nil { + return nil, err + } + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + } + return newCert(mockdb, *ops) +} + +func TestNewCert(t *testing.T) { + type test struct { + db nosql.DB + ops CertOptions + err *Error + id *string + } + tests := map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + ops, err := defaultCertOps() + assert.FatalError(t, err) + return test{ + ops: *ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, old, nil) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error storing certificate: force")), + } + }, + "fail/cmpAndSwap-false": func(t *testing.T) test { + ops, err := defaultCertOps() + assert.FatalError(t, err) + return test{ + ops: *ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, old, nil) + return nil, false, nil + }, + }, + err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")), + } + }, + "ok": func(t *testing.T) test { + ops, err := defaultCertOps() + assert.FatalError(t, err) + var _id string + id := &_id + return test{ + ops: *ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, old, nil) + *id = string(key) + return nil, true, nil + }, + }, + id: id, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if cert, err := newCert(tc.db, tc.ops); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, cert.ID, *tc.id) + assert.Equals(t, cert.AccountID, tc.ops.AccountID) + assert.Equals(t, cert.OrderID, tc.ops.OrderID) + + leaf := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: tc.ops.Leaf.Raw, + }) + var intermediates []byte + for _, cert := range tc.ops.Intermediates { + intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })...) + } + assert.Equals(t, cert.Leaf, leaf) + assert.Equals(t, cert.Intermediates, intermediates) + + assert.True(t, cert.Created.Before(time.Now().Add(time.Minute))) + assert.True(t, cert.Created.After(time.Now().Add(-time.Minute))) + } + } + }) + } +} + +func TestGetCert(t *testing.T) { + type test struct { + id string + db nosql.DB + cert *certificate + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + cert, err := newcert() + assert.FatalError(t, err) + return test{ + cert: cert, + id: cert.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)), + } + }, + "fail/db-error": func(t *testing.T) test { + cert, err := newcert() + assert.FatalError(t, err) + return test{ + cert: cert, + id: cert.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading certificate: force")), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + cert, err := newcert() + assert.FatalError(t, err) + return test{ + cert: cert, + id: cert.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + return nil, nil + }, + }, + err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")), + } + }, + "ok": func(t *testing.T) test { + cert, err := newcert() + assert.FatalError(t, err) + b, err := json.Marshal(cert) + assert.FatalError(t, err) + return test{ + cert: cert, + id: cert.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if cert, err := getCert(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.ID, cert.ID) + assert.Equals(t, tc.cert.AccountID, cert.AccountID) + assert.Equals(t, tc.cert.OrderID, cert.OrderID) + assert.Equals(t, tc.cert.Created, cert.Created) + assert.Equals(t, tc.cert.Leaf, cert.Leaf) + assert.Equals(t, tc.cert.Intermediates, cert.Intermediates) + } + } + }) + } +} + +func TestCertificateToACME(t *testing.T) { + cert, err := newcert() + assert.FatalError(t, err) + acmeCert, err := cert.toACME(nil, nil) + assert.FatalError(t, err) + assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert) +} diff --git a/acme/challenge.go b/acme/challenge.go new file mode 100644 index 00000000..4fa39668 --- /dev/null +++ b/acme/challenge.go @@ -0,0 +1,445 @@ +package acme + +import ( + "crypto" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql" +) + +// Challenge is a subset of the challenge type containing only those attributes +// required for responses in the ACME protocol. +type Challenge struct { + Type string `json:"type"` + Status string `json:"status"` + Token string `json:"token"` + Validated string `json:"validated,omitempty"` + URL string `json:"url"` + Error *AError `json:"error,omitempty"` + ID string `json:"-"` + AuthzID string `json:"-"` +} + +// ToLog enables response logging. +func (c *Challenge) ToLog() (interface{}, error) { + b, err := json.Marshal(c) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) + } + return string(b), nil +} + +// GetID returns the Challenge ID. +func (c *Challenge) GetID() string { + return c.ID +} + +// GetAuthzID returns the parent Authz ID that owns the Challenge. +func (c *Challenge) GetAuthzID() string { + return c.AuthzID +} + +type httpGetter func(string) (*http.Response, error) +type lookupTxt func(string) ([]string, error) + +type validateOptions struct { + httpGet httpGetter + lookupTxt lookupTxt +} + +// challenge is the interface ACME challenege types must implement. +type challenge interface { + save(db nosql.DB, swap challenge) error + validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error) + getType() string + getError() *AError + getValue() string + getStatus() string + getID() string + getAuthzID() string + getToken() string + clone() *baseChallenge + getAccountID() string + getValidated() time.Time + getCreated() time.Time + toACME(nosql.DB, *directory, provisioner.Interface) (*Challenge, error) +} + +// ChallengeOptions is the type used to created a new Challenge. +type ChallengeOptions struct { + AccountID string + AuthzID string + Identifier Identifier +} + +// baseChallenge is the base Challenge type that others build from. +type baseChallenge struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + AuthzID string `json:"authzID"` + Type string `json:"type"` + Status string `json:"status"` + Token string `json:"token"` + Value string `json:"value"` + Validated time.Time `json:"validated"` + Created time.Time `json:"created"` + Error *AError `json:"error"` +} + +func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) { + id, err := randID() + if err != nil { + return nil, Wrap(err, "error generating random id for ACME challenge") + } + token, err := randID() + if err != nil { + return nil, Wrap(err, "error generating token for ACME challenge") + } + + return &baseChallenge{ + ID: id, + AccountID: accountID, + AuthzID: authzID, + Status: StatusPending, + Token: token, + Created: clock.Now(), + }, nil +} + +// getID returns the id of the baseChallenge. +func (bc *baseChallenge) getID() string { + return bc.ID +} + +// getAuthzID returns the authz ID of the baseChallenge. +func (bc *baseChallenge) getAuthzID() string { + return bc.AuthzID +} + +// getAccountID returns the account id of the baseChallenge. +func (bc *baseChallenge) getAccountID() string { + return bc.AccountID +} + +// getType returns the type of the baseChallenge. +func (bc *baseChallenge) getType() string { + return bc.Type +} + +// getValue returns the type of the baseChallenge. +func (bc *baseChallenge) getValue() string { + return bc.Value +} + +// getStatus returns the status of the baseChallenge. +func (bc *baseChallenge) getStatus() string { + return bc.Status +} + +// getToken returns the token of the baseChallenge. +func (bc *baseChallenge) getToken() string { + return bc.Token +} + +// getValidated returns the validated time of the baseChallenge. +func (bc *baseChallenge) getValidated() time.Time { + return bc.Validated +} + +// getCreated returns the created time of the baseChallenge. +func (bc *baseChallenge) getCreated() time.Time { + return bc.Created +} + +// getCreated returns the created time of the baseChallenge. +func (bc *baseChallenge) getError() *AError { + return bc.Error +} + +// toACME converts the internal Challenge type into the public acmeChallenge +// type for presentation in the ACME protocol. +func (bc *baseChallenge) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Challenge, error) { + ac := &Challenge{ + Type: bc.getType(), + Status: bc.getStatus(), + Token: bc.getToken(), + URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()), + ID: bc.getID(), + AuthzID: bc.getAuthzID(), + } + if !bc.Validated.IsZero() { + ac.Validated = bc.Validated.Format(time.RFC3339) + } + if bc.Error != nil { + ac.Error = bc.Error + } + return ac, nil +} + +// save writes the challenge to disk. For new challenges 'old' should be nil, +// otherwise 'old' should be a pointer to the acme challenge as it was at the +// start of the request. This method will fail if the value currently found +// in the bucket/row does not match the value of 'old'. +func (bc *baseChallenge) save(db nosql.DB, old challenge) error { + newB, err := json.Marshal(bc) + if err != nil { + return ServerInternalErr(errors.Wrap(err, + "error marshaling new acme challenge")) + } + var oldB []byte + if old == nil { + oldB = nil + } else { + oldB, err = json.Marshal(old) + if err != nil { + return ServerInternalErr(errors.Wrap(err, + "error marshaling old acme challenge")) + } + } + + _, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error saving acme challenge")) + case !swapped: + return ServerInternalErr(errors.New("error saving acme challenge; " + + "acme challenge has changed since last read")) + default: + return nil + } +} + +func (bc *baseChallenge) clone() *baseChallenge { + u := *bc + return &u +} + +func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { + return nil, ServerInternalErr(errors.New("unimplemented")) +} + +func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error { + clone := bc.clone() + clone.Error = err.ToACME() + return clone.save(db, bc) +} + +// unmarshalChallenge unmarshals a challenge type into the correct sub-type. +func unmarshalChallenge(data []byte) (challenge, error) { + var getType struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &getType); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type")) + } + + switch getType.Type { + case "dns-01": + var bc baseChallenge + if err := json.Unmarshal(data, &bc); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ + "challenge type into dns01Challenge")) + } + return &dns01Challenge{&bc}, nil + case "http-01": + var bc baseChallenge + if err := json.Unmarshal(data, &bc); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ + "challenge type into http01Challenge")) + } + return &http01Challenge{&bc}, nil + default: + return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type)) + } +} + +// http01Challenge represents an http-01 acme challenge. +type http01Challenge struct { + *baseChallenge +} + +// newHTTP01Challenge returns a new acme http-01 challenge. +func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { + bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) + if err != nil { + return nil, err + } + bc.Type = "http-01" + bc.Value = ops.Identifier.Value + + hc := &http01Challenge{bc} + if err := hc.save(db, nil); err != nil { + return nil, err + } + return hc, nil +} + +// Validate attempts to validate the challenge. If the challenge has been +// satisfactorily validated, the 'status' and 'validated' attributes are +// updated. +func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { + // If already valid or invalid then return without performing validation. + if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid { + return hc, nil + } + url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token) + + resp, err := vo.httpGet(url) + if err != nil { + if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err, + "error doing http GET for url %s", url))); err != nil { + return nil, err + } + return hc, nil + } + if resp.StatusCode >= 400 { + if err = hc.storeError(db, + ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", + url, resp.StatusCode))); err != nil { + return nil, err + } + return hc, nil + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+ + "response body for url %s", url)) + } + keyAuth := strings.Trim(string(body), "\r\n") + + expected, err := KeyAuthorization(hc.Token, jwk) + if err != nil { + return nil, err + } + if keyAuth != expected { + if err = hc.storeError(db, + RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ + "expected %s, but got %s", expected, keyAuth))); err != nil { + return nil, err + } + return hc, nil + } + + // Update and store the challenge. + upd := &http01Challenge{hc.baseChallenge.clone()} + upd.Status = StatusValid + upd.Error = nil + upd.Validated = clock.Now() + + if err := upd.save(db, hc); err != nil { + return nil, err + } + return upd, nil +} + +// dns01Challenge represents an dns-01 acme challenge. +type dns01Challenge struct { + *baseChallenge +} + +// newDNS01Challenge returns a new acme dns-01 challenge. +func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { + bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) + if err != nil { + return nil, err + } + bc.Type = "dns-01" + bc.Value = ops.Identifier.Value + + dc := &dns01Challenge{bc} + if err := dc.save(db, nil); err != nil { + return nil, err + } + return dc, nil +} + +// KeyAuthorization creates the ACME key authorization value from a token +// and a jwk. +func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { + thumbprint, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) + } + encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) + return fmt.Sprintf("%s.%s", token, encPrint), nil +} + +// validate attempts to validate the challenge. If the challenge has been +// satisfactorily validated, the 'status' and 'validated' attributes are +// updated. +func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { + // If already valid or invalid then return without performing validation. + if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid { + return dc, nil + } + + txtRecords, err := vo.lookupTxt("_acme-challenge." + dc.Value) + if err != nil { + if err = dc.storeError(db, + DNSErr(errors.Wrapf(err, "error looking up TXT "+ + "records for domain %s", dc.Value))); err != nil { + return nil, err + } + return dc, nil + } + + expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) + if err != nil { + return nil, err + } + h := sha256.Sum256([]byte(expectedKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + var found bool + for _, r := range txtRecords { + if r == expected { + found = true + break + } + } + if !found { + if err = dc.storeError(db, + RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ + "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil { + return nil, err + } + return dc, nil + } + + // Update and store the challenge. + upd := &dns01Challenge{dc.baseChallenge.clone()} + upd.Status = StatusValid + upd.Error = nil + upd.Validated = time.Now().UTC() + + if err := upd.save(db, dc); err != nil { + return nil, err + } + return upd, nil +} + +// getChallenge retrieves and unmarshals an ACME challenge type from the database. +func getChallenge(db nosql.DB, id string) (challenge, error) { + b, err := db.Get(challengeTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id)) + } + ch, err := unmarshalChallenge(b) + if err != nil { + return nil, err + } + return ch, nil +} diff --git a/acme/challenge_test.go b/acme/challenge_test.go new file mode 100644 index 00000000..6291803d --- /dev/null +++ b/acme/challenge_test.go @@ -0,0 +1,1093 @@ +package acme + +import ( + "bytes" + "crypto" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/cli/jose" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +var testOps = ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "", // will get set correctly depending on the "new.." method. + Value: "zap.internal", + }, +} + +func newDNSCh() (challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newDNS01Challenge(mockdb, testOps) +} + +func newHTTPCh() (challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newHTTP01Challenge(mockdb, testOps) +} + +func TestNewHTTP01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "http", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newHTTP01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "http-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestNewDNS01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "dns", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newDNS01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "dns-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestChallengeToACME(t *testing.T) { + dir := newDirectory("ca.smallstep.com", "acme") + + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Validated = clock.Now() + + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + prov := newProv() + tests := map[string]challenge{ + "dns": dnsCh, + "http": httpCh, + } + for name, ch := range tests { + t.Run(name, func(t *testing.T) { + ach, err := ch.toACME(nil, dir, prov) + assert.FatalError(t, err) + + assert.Equals(t, ach.Type, ch.getType()) + assert.Equals(t, ach.Status, ch.getStatus()) + assert.Equals(t, ach.Token, ch.getToken()) + assert.Equals(t, ach.URL, + fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/%s", + URLSafeProvisionerName(prov), ch.getID())) + assert.Equals(t, ach.ID, ch.getID()) + assert.Equals(t, ach.AuthzID, ch.getAuthzID()) + + if ach.Type == "http-01" { + v, err := time.Parse(time.RFC3339, ach.Validated) + assert.FatalError(t, err) + assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) + } else { + assert.Equals(t, ach.Validated, "") + } + }) + } +} + +func TestChallengeSave(t *testing.T) { + type test struct { + ch challenge + old challenge + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + } + }, + "fail/old-nil/swap-false": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), + } + }, + "ok/old-nil": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, []byte(httpCh.getID()), key) + return []byte("foo"), true, nil + }, + }, + } + }, + "ok/old-not-nil": func(t *testing.T) test { + oldHTTPCh, err := newHTTPCh() + assert.FatalError(t, err) + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + + oldb, err := json.Marshal(oldHTTPCh) + assert.FatalError(t, err) + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: oldHTTPCh, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, oldb) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, []byte(httpCh.getID()), key) + return []byte("foo"), true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.ch.save(tc.db, tc.old); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestChallengeClone(t *testing.T) { + ch, err := newHTTPCh() + assert.FatalError(t, err) + + clone := ch.clone() + + assert.Equals(t, clone.getID(), ch.getID()) + assert.Equals(t, clone.getAccountID(), ch.getAccountID()) + assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, clone.getStatus(), ch.getStatus()) + assert.Equals(t, clone.getToken(), ch.getToken()) + assert.Equals(t, clone.getCreated(), ch.getCreated()) + assert.Equals(t, clone.getValidated(), ch.getValidated()) + + clone.Status = StatusValid + + assert.NotEquals(t, clone.getStatus(), ch.getStatus()) +} + +func TestChallengeUnmarshal(t *testing.T) { + type test struct { + ch challenge + chb []byte + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/nil": func(t *testing.T) test { + return test{ + chb: nil, + err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), + } + }, + "fail/unexpected-type": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Type = "foo" + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), + } + }, + "ok/dns": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + ch: dnsCh, + chb: b, + } + }, + "ok/http": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + chb: b, + } + }, + "ok/err": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + chb: b, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := unmarshalChallenge(tc.chb); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.getID(), ch.getID()) + assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) + assert.Equals(t, tc.ch.getToken(), ch.getToken()) + assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) + assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) + } + } + }) + } +} +func TestGetChallenge(t *testing.T) { + type test struct { + id string + db nosql.DB + ch challenge + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), + } + }, + "fail/db-error": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + _dnsCh, ok := dnsCh.(*dns01Challenge) + assert.Fatal(t, ok) + _dnsCh.baseChallenge.Type = "foo" + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(dnsCh.getID())) + return b, nil + }, + }, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), + } + }, + "ok": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(dnsCh.getID())) + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := getChallenge(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.getID(), ch.getID()) + assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) + assert.Equals(t, tc.ch.getToken(), ch.getToken()) + assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) + assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) + } + } + }) + } +} + +func TestKeyAuthorization(t *testing.T) { + type test struct { + token string + jwk *jose.JSONWebKey + exp string + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/jwk-thumbprint-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + return test{ + token: "1234", + jwk: jwk, + err: ServerInternalErr(errors.Errorf("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + } + }, + "ok": func(t *testing.T) test { + token := "1234" + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + thumbprint, err := jwk.Thumbprint(crypto.SHA256) + assert.FatalError(t, err) + encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) + return test{ + token: token, + jwk: jwk, + exp: fmt.Sprintf("%s.%s", token, encPrint), + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.exp, ka) + } + } + }) + } +} + +type errReader int + +func (errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("force") +} +func (errReader) Close() error { + return nil +} + +func TestHTTP01Validate(t *testing.T) { + type test struct { + vo validateOptions + ch challenge + res challenge + jwk *jose.JSONWebKey + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/status-already-valid": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + _ch, ok := ch.(*http01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusValid + return test{ + ch: ch, + res: ch, + } + }, + "ok/status-already-invalid": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + _ch, ok := ch.(*http01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusInvalid + return test{ + ch: ch, + res: ch, + } + }, + "ok/http-get-error": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ + "http://zap.internal/.well-known/acme-challenge/%s: force", ch.getToken())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &http01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + return test{ + ch: ch, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "ok/http-get->=400": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ + "http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.getToken())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &http01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + return test{ + ch: ch, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + }, nil + }, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "fail/read-body": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + + return test{ + ch: ch, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: errReader(0), + }, nil + }, + }, + jwk: jwk, + err: ServerInternalErr(errors.Errorf("error reading response "+ + "body for url http://zap.internal/.well-known/acme-challenge/%s: force", + ch.getToken())), + } + }, + "fail/key-authorization-gen-error": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + return test{ + ch: ch, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), + }, nil + }, + }, + jwk: jwk, + err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + } + }, + "ok/key-auth-mismatch": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + + expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ + "expected %s, but got foo", expKeyAuth)) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &http01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + + return test{ + ch: ch, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), + }, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "fail/save-error": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + return test{ + ch: ch, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + }, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + } + }, + "ok": func(t *testing.T) test { + ch, err := newHTTPCh() + assert.FatalError(t, err) + _ch, ok := ch.(*http01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Error = MalformedErr(nil).ToACME() + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + + baseClone := ch.clone() + baseClone.Status = StatusValid + baseClone.Error = nil + newCh := &http01Challenge{baseClone} + + return test{ + ch: ch, + res: newCh, + vo: validateOptions{ + httpGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + }, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + + httpCh, err := unmarshalChallenge(newval) + assert.FatalError(t, err) + assert.Equals(t, httpCh.getStatus(), StatusValid) + assert.True(t, httpCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, httpCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + + baseClone.Validated = httpCh.getValidated() + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res.getID(), ch.getID()) + assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.res.getStatus(), ch.getStatus()) + assert.Equals(t, tc.res.getToken(), ch.getToken()) + assert.Equals(t, tc.res.getCreated(), ch.getCreated()) + assert.Equals(t, tc.res.getValidated(), ch.getValidated()) + assert.Equals(t, tc.res.getError(), ch.getError()) + } + } + }) + } +} + +func TestDNS01Validate(t *testing.T) { + type test struct { + vo validateOptions + ch challenge + res challenge + jwk *jose.JSONWebKey + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/status-already-valid": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusValid + return test{ + ch: ch, + res: ch, + } + }, + "ok/status-already-invalid": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Status = StatusInvalid + return test{ + ch: ch, + res: ch, + } + }, + "ok/lookup-txt-error": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + expErr := DNSErr(errors.Errorf("error looking up TXT records for "+ + "domain %s: force", ch.getValue())) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &dns01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "fail/key-authorization-gen-error": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwk.Key = "foo" + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil + }, + }, + jwk: jwk, + err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + } + }, + "ok/key-auth-mismatch": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + + expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ + "expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})) + baseClone := ch.clone() + baseClone.Error = expErr.ToACME() + newCh := &http01Challenge{baseClone} + newb, err := json.Marshal(newCh) + assert.FatalError(t, err) + + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + assert.Equals(t, newval, newb) + return nil, true, nil + }, + }, + res: ch, + } + }, + "fail/save-error": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + return test{ + ch: ch, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + } + }, + "ok": func(t *testing.T) test { + ch, err := newDNSCh() + assert.FatalError(t, err) + _ch, ok := ch.(*dns01Challenge) + assert.Fatal(t, ok) + _ch.baseChallenge.Error = MalformedErr(nil).ToACME() + oldb, err := json.Marshal(ch) + assert.FatalError(t, err) + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + assert.FatalError(t, err) + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + + baseClone := ch.clone() + baseClone.Status = StatusValid + baseClone.Error = nil + newCh := &dns01Challenge{baseClone} + + return test{ + ch: ch, + res: newCh, + vo: validateOptions{ + lookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + jwk: jwk, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(ch.getID())) + assert.Equals(t, old, oldb) + + dnsCh, err := unmarshalChallenge(newval) + assert.FatalError(t, err) + assert.Equals(t, dnsCh.getStatus(), StatusValid) + assert.True(t, dnsCh.getValidated().Before(time.Now().UTC())) + assert.True(t, dnsCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + + baseClone.Validated = dnsCh.getValidated() + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res.getID(), ch.getID()) + assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.res.getStatus(), ch.getStatus()) + assert.Equals(t, tc.res.getToken(), ch.getToken()) + assert.Equals(t, tc.res.getCreated(), ch.getCreated()) + assert.Equals(t, tc.res.getValidated(), ch.getValidated()) + assert.Equals(t, tc.res.getError(), ch.getError()) + } + } + }) + } +} diff --git a/acme/common.go b/acme/common.go new file mode 100644 index 00000000..577c35cd --- /dev/null +++ b/acme/common.go @@ -0,0 +1,76 @@ +package acme + +import ( + "crypto/x509" + "net/url" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/crypto/randutil" +) + +// SignAuthority is the interface implemented by a CA authority. +type SignAuthority interface { + Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) + LoadProvisionerByID(string) (provisioner.Interface, error) +} + +// Identifier encodes the type that an order pertains to. +type Identifier struct { + Type string `json:"type"` + Value string `json:"value"` +} + +var ( + accountTable = []byte("acme-accounts") + accountByKeyIDTable = []byte("acme-keyID-accountID-index") + authzTable = []byte("acme-authzs") + challengeTable = []byte("acme-challenges") + nonceTable = []byte("nonce-table") + orderTable = []byte("acme-orders") + ordersByAccountIDTable = []byte("acme-account-orders-index") + certTable = []byte("acme-certs") +) + +var ( + // StatusValid -- valid + StatusValid = "valid" + // StatusInvalid -- invalid + StatusInvalid = "invalid" + // StatusPending -- pending; e.g. an Order that is not ready to be finalized. + StatusPending = "pending" + // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid. + StatusDeactivated = "deactivated" + // StatusReady -- ready; e.g. for an Order that is ready to be finalized. + StatusReady = "ready" + //statusExpired = "expired" + //statusActive = "active" + //statusProcessing = "processing" +) + +var idLen = 32 + +func randID() (val string, err error) { + val, err = randutil.Alphanumeric(idLen) + if err != nil { + return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID")) + } + return val, nil +} + +// Clock that returns time in UTC rounded to seconds. +type Clock int + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Round(time.Second) +} + +var clock = new(Clock) + +// URLSafeProvisionerName returns a path escaped version of the ACME provisioner +// ID that is safe to use in URL paths. +func URLSafeProvisionerName(p provisioner.Interface) string { + return url.PathEscape(p.GetName()) +} diff --git a/acme/directory.go b/acme/directory.go new file mode 100644 index 00000000..85819f10 --- /dev/null +++ b/acme/directory.go @@ -0,0 +1,120 @@ +package acme + +import ( + "encoding/json" + "fmt" + + "github.com/pkg/errors" +) + +// Directory represents an ACME directory for configuring clients. +type Directory struct { + NewNonce string `json:"newNonce,omitempty"` + NewAccount string `json:"newAccount,omitempty"` + NewOrder string `json:"newOrder,omitempty"` + NewAuthz string `json:"newAuthz,omitempty"` + RevokeCert string `json:"revokeCert,omitempty"` + KeyChange string `json:"keyChange,omitempty"` +} + +// ToLog enables response logging for the Directory type. +func (d *Directory) ToLog() (interface{}, error) { + b, err := json.Marshal(d) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging")) + } + return string(b), nil +} + +type directory struct { + prefix, dns string +} + +// newDirectory returns a new Directory type. +func newDirectory(dns, prefix string) *directory { + return &directory{prefix: prefix, dns: dns} +} + +// Link captures the link type. +type Link int + +const ( + // NewNonceLink new-nonce + NewNonceLink Link = iota + // NewAccountLink new-account + NewAccountLink + // AccountLink account + AccountLink + // OrderLink order + OrderLink + // NewOrderLink new-order + NewOrderLink + // OrdersByAccountLink list of orders owned by account + OrdersByAccountLink + // FinalizeLink finalize order + FinalizeLink + // NewAuthzLink authz + NewAuthzLink + // AuthzLink new-authz + AuthzLink + // ChallengeLink challenge + ChallengeLink + // CertificateLink certificate + CertificateLink + // DirectoryLink directory + DirectoryLink + // RevokeCertLink revoke certificate + RevokeCertLink + // KeyChangeLink key rollover + KeyChangeLink +) + +func (l Link) String() string { + switch l { + case NewNonceLink: + return "new-nonce" + case NewAccountLink: + return "new-account" + case AccountLink: + return "account" + case NewOrderLink: + return "new-order" + case OrderLink: + return "order" + case NewAuthzLink: + return "new-authz" + case AuthzLink: + return "authz" + case ChallengeLink: + return "challenge" + case CertificateLink: + return "certificate" + case DirectoryLink: + return "directory" + case RevokeCertLink: + return "revoke-cert" + case KeyChangeLink: + return "key-change" + default: + return "unexpected" + } +} + +// getLink returns an absolute or partial path to the given resource. +func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string { + var link string + switch typ { + case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: + link = fmt.Sprintf("/%s/%s", provisionerName, typ.String()) + case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink: + link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0]) + case OrdersByAccountLink: + link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0]) + case FinalizeLink: + link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) + } + if abs { + return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link) + } + return link +} diff --git a/acme/directory_test.go b/acme/directory_test.go new file mode 100644 index 00000000..14fd421f --- /dev/null +++ b/acme/directory_test.go @@ -0,0 +1,63 @@ +package acme + +import ( + "fmt" + "testing" + + "github.com/smallstep/assert" +) + +func TestDirectoryGetLink(t *testing.T) { + dns := "ca.smallstep.com" + prefix := "acme" + dir := newDirectory(dns, prefix) + id := "1234" + + prov := newProv() + provID := URLSafeProvisionerName(prov) + + type newTest struct { + actual, expected string + } + assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID)) + assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID)) + + assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID)) + assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID)) + + assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID)) + assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID)) + + assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID)) + assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID)) + + assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID)) + assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID)) + + assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID)) + assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID)) + + assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID)) + assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) + + assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID)) + assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID)) + + assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID)) + assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID)) + + assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID)) + assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID)) + + assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID)) + assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID)) + + assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID)) + assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID)) + + assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID)) + assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID)) + + assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID)) + assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID)) +} diff --git a/acme/errors.go b/acme/errors.go new file mode 100644 index 00000000..9facac5f --- /dev/null +++ b/acme/errors.go @@ -0,0 +1,439 @@ +package acme + +import ( + "github.com/pkg/errors" +) + +// AccountDoesNotExistErr returns a new acme error. +func AccountDoesNotExistErr(err error) *Error { + return &Error{ + Type: accountDoesNotExistErr, + Detail: "Account does not exist", + Status: 404, + Err: err, + } +} + +// AlreadyRevokedErr returns a new acme error. +func AlreadyRevokedErr(err error) *Error { + return &Error{ + Type: alreadyRevokedErr, + Detail: "Certificate already revoked", + Status: 400, + Err: err, + } +} + +// BadCSRErr returns a new acme error. +func BadCSRErr(err error) *Error { + return &Error{ + Type: badCSRErr, + Detail: "The CSR is unacceptable", + Status: 400, + Err: err, + } +} + +// BadNonceErr returns a new acme error. +func BadNonceErr(err error) *Error { + return &Error{ + Type: badNonceErr, + Detail: "Unacceptable anti-replay nonce", + Status: 400, + Err: err, + } +} + +// BadPublicKeyErr returns a new acme error. +func BadPublicKeyErr(err error) *Error { + return &Error{ + Type: badPublicKeyErr, + Detail: "The jws was signed by a public key the server does not support", + Status: 400, + Err: err, + } +} + +// BadRevocationReasonErr returns a new acme error. +func BadRevocationReasonErr(err error) *Error { + return &Error{ + Type: badRevocationReasonErr, + Detail: "The revocation reason provided is not allowed by the server", + Status: 400, + Err: err, + } +} + +// BadSignatureAlgorithmErr returns a new acme error. +func BadSignatureAlgorithmErr(err error) *Error { + return &Error{ + Type: badSignatureAlgorithmErr, + Detail: "The JWS was signed with an algorithm the server does not support", + Status: 400, + Err: err, + } +} + +// CaaErr returns a new acme error. +func CaaErr(err error) *Error { + return &Error{ + Type: caaErr, + Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate", + Status: 400, + Err: err, + } +} + +// CompoundErr returns a new acme error. +func CompoundErr(err error) *Error { + return &Error{ + Type: compoundErr, + Detail: "Specific error conditions are indicated in the “subproblems” array", + Status: 400, + Err: err, + } +} + +// ConnectionErr returns a new acme error. +func ConnectionErr(err error) *Error { + return &Error{ + Type: connectionErr, + Detail: "The server could not connect to validation target", + Status: 400, + Err: err, + } +} + +// DNSErr returns a new acme error. +func DNSErr(err error) *Error { + return &Error{ + Type: dnsErr, + Detail: "There was a problem with a DNS query during identifier validation", + Status: 400, + Err: err, + } +} + +// ExternalAccountRequiredErr returns a new acme error. +func ExternalAccountRequiredErr(err error) *Error { + return &Error{ + Type: externalAccountRequiredErr, + Detail: "The request must include a value for the \"externalAccountBinding\" field", + Status: 400, + Err: err, + } +} + +// IncorrectResponseErr returns a new acme error. +func IncorrectResponseErr(err error) *Error { + return &Error{ + Type: incorrectResponseErr, + Detail: "Response received didn't match the challenge's requirements", + Status: 400, + Err: err, + } +} + +// InvalidContactErr returns a new acme error. +func InvalidContactErr(err error) *Error { + return &Error{ + Type: invalidContactErr, + Detail: "A contact URL for an account was invalid", + Status: 400, + Err: err, + } +} + +// MalformedErr returns a new acme error. +func MalformedErr(err error) *Error { + return &Error{ + Type: malformedErr, + Detail: "The request message was malformed", + Status: 400, + Err: err, + } +} + +// OrderNotReadyErr returns a new acme error. +func OrderNotReadyErr(err error) *Error { + return &Error{ + Type: orderNotReadyErr, + Detail: "The request attempted to finalize an order that is not ready to be finalized", + Status: 400, + Err: err, + } +} + +// RateLimitedErr returns a new acme error. +func RateLimitedErr(err error) *Error { + return &Error{ + Type: rateLimitedErr, + Detail: "The request exceeds a rate limit", + Status: 400, + Err: err, + } +} + +// RejectedIdentifierErr returns a new acme error. +func RejectedIdentifierErr(err error) *Error { + return &Error{ + Type: rejectedIdentifierErr, + Detail: "The server will not issue certificates for the identifier", + Status: 400, + Err: err, + } +} + +// ServerInternalErr returns a new acme error. +func ServerInternalErr(err error) *Error { + return &Error{ + Type: serverInternalErr, + Detail: "The server experienced an internal error", + Status: 500, + Err: err, + } +} + +// TLSErr returns a new acme error. +func TLSErr(err error) *Error { + return &Error{ + Type: tlsErr, + Detail: "The server received a TLS error during validation", + Status: 400, + Err: err, + } +} + +// UnauthorizedErr returns a new acme error. +func UnauthorizedErr(err error) *Error { + return &Error{ + Type: unauthorizedErr, + Detail: "The client lacks sufficient authorization", + Status: 401, + Err: err, + } +} + +// UnsupportedContactErr returns a new acme error. +func UnsupportedContactErr(err error) *Error { + return &Error{ + Type: unsupportedContactErr, + Detail: "A contact URL for an account used an unsupported protocol scheme", + Status: 400, + Err: err, + } +} + +// UnsupportedIdentifierErr returns a new acme error. +func UnsupportedIdentifierErr(err error) *Error { + return &Error{ + Type: unsupportedIdentifierErr, + Detail: "An identifier is of an unsupported type", + Status: 400, + Err: err, + } +} + +// UserActionRequiredErr returns a new acme error. +func UserActionRequiredErr(err error) *Error { + return &Error{ + Type: userActionRequiredErr, + Detail: "Visit the “instance” URL and take actions specified there", + Status: 400, + Err: err, + } +} + +// ProbType is the type of the ACME problem. +type ProbType int + +const ( + // The request specified an account that does not exist + accountDoesNotExistErr ProbType = iota + // The request specified a certificate to be revoked that has already been revoked + alreadyRevokedErr + // The CSR is unacceptable (e.g., due to a short key) + badCSRErr + // The client sent an unacceptable anti-replay nonce + badNonceErr + // The JWS was signed by a public key the server does not support + badPublicKeyErr + // The revocation reason provided is not allowed by the server + badRevocationReasonErr + // The JWS was signed with an algorithm the server does not support + badSignatureAlgorithmErr + // Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate + caaErr + // Specific error conditions are indicated in the “subproblems” array. + compoundErr + // The server could not connect to validation target + connectionErr + // There was a problem with a DNS query during identifier validation + dnsErr + // The request must include a value for the “externalAccountBinding” field + externalAccountRequiredErr + // Response received didn’t match the challenge’s requirements + incorrectResponseErr + // A contact URL for an account was invalid + invalidContactErr + // The request message was malformed + malformedErr + // The request attempted to finalize an order that is not ready to be finalized + orderNotReadyErr + // The request exceeds a rate limit + rateLimitedErr + // The server will not issue certificates for the identifier + rejectedIdentifierErr + // The server experienced an internal error + serverInternalErr + // The server received a TLS error during validation + tlsErr + // The client lacks sufficient authorization + unauthorizedErr + // A contact URL for an account used an unsupported protocol scheme + unsupportedContactErr + // An identifier is of an unsupported type + unsupportedIdentifierErr + // Visit the “instance” URL and take actions specified there + userActionRequiredErr +) + +// String returns the string representation of the acme problem type, +// fulfilling the Stringer interface. +func (ap ProbType) String() string { + switch ap { + case accountDoesNotExistErr: + return "accountDoesNotExist" + case alreadyRevokedErr: + return "alreadyRevoked" + case badCSRErr: + return "badCSR" + case badNonceErr: + return "badNonce" + case badPublicKeyErr: + return "badPublicKey" + case badRevocationReasonErr: + return "badRevocationReason" + case badSignatureAlgorithmErr: + return "badSignatureAlgorithm" + case caaErr: + return "caa" + case compoundErr: + return "compound" + case connectionErr: + return "connection" + case dnsErr: + return "dns" + case externalAccountRequiredErr: + return "externalAccountRequired" + case incorrectResponseErr: + return "incorrectResponse" + case invalidContactErr: + return "invalidContact" + case malformedErr: + return "malformed" + case orderNotReadyErr: + return "orderNotReady" + case rateLimitedErr: + return "rateLimited" + case rejectedIdentifierErr: + return "rejectedIdentifier" + case serverInternalErr: + return "serverInternal" + case tlsErr: + return "tls" + case unauthorizedErr: + return "unauthorized" + case unsupportedContactErr: + return "unsupportedContact" + case unsupportedIdentifierErr: + return "unsupportedIdentifier" + case userActionRequiredErr: + return "userActionRequired" + default: + return "unsupported type" + } +} + +// Error is an ACME error type complete with problem document. +type Error struct { + Type ProbType + Detail string + Err error + Status int + Sub []*Error + Identifier *Identifier +} + +// Wrap attempts to wrap the internal error. +func Wrap(err error, wrap string) *Error { + switch e := err.(type) { + case nil: + return nil + case *Error: + if e.Err == nil { + e.Err = errors.New(wrap + "; " + e.Detail) + } else { + e.Err = errors.Wrap(e.Err, wrap) + } + return e + default: + return ServerInternalErr(errors.Wrap(err, wrap)) + } +} + +// Error implements the error interface. +func (e *Error) Error() string { + if e.Err == nil { + return e.Detail + } + return e.Err.Error() +} + +// Cause returns the internal error and implements the Causer interface. +func (e *Error) Cause() error { + if e.Err == nil { + return errors.New(e.Detail) + } + return e.Err +} + +// ToACME returns an acme representation of the problem type. +func (e *Error) ToACME() *AError { + ae := &AError{ + Type: "urn:ietf:params:acme:error:" + e.Type.String(), + Detail: e.Error(), + Status: e.Status, + } + if e.Identifier != nil { + ae.Identifier = *e.Identifier + } + for _, p := range e.Sub { + ae.Subproblems = append(ae.Subproblems, p.ToACME()) + } + return ae +} + +// StatusCode returns the status code and implements the StatusCode interface. +func (e *Error) StatusCode() int { + return e.Status +} + +// AError is the error type as seen in acme request/responses. +type AError struct { + Type string `json:"type"` + Detail string `json:"detail"` + Identifier interface{} `json:"identifier,omitempty"` + Subproblems []interface{} `json:"subproblems,omitempty"` + Status int `json:"-"` +} + +// Error allows AError to implement the error interface. +func (ae *AError) Error() string { + return ae.Detail +} + +// StatusCode returns the status code and implements the StatusCode interface. +func (ae *AError) StatusCode() int { + return ae.Status +} diff --git a/acme/nonce.go b/acme/nonce.go new file mode 100644 index 00000000..db680f08 --- /dev/null +++ b/acme/nonce.go @@ -0,0 +1,73 @@ +package acme + +import ( + "encoding/base64" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +// nonce contains nonce metadata used in the ACME protocol. +type nonce struct { + ID string + Created time.Time +} + +// newNonce creates, stores, and returns an ACME replay-nonce. +func newNonce(db nosql.DB) (*nonce, error) { + _id, err := randID() + if err != nil { + return nil, err + } + + id := base64.RawURLEncoding.EncodeToString([]byte(_id)) + n := &nonce{ + ID: id, + Created: clock.Now(), + } + b, err := json.Marshal(n) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) + } + _, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b) + switch { + case err != nil: + return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce")) + case !swapped: + return nil, ServerInternalErr(errors.New("error storing nonce; " + + "value has changed since last read")) + default: + return n, nil + } +} + +// useNonce verifies that the nonce is valid (by checking if it exists), +// and if so, consumes the nonce resource by deleting it from the database. +func useNonce(db nosql.DB, nonce string) error { + err := db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Get, + }, + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Delete, + }, + }, + }) + + switch { + case nosql.IsErrNotFound(err): + return BadNonceErr(nil) + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce)) + default: + return nil + } +} diff --git a/acme/nonce_test.go b/acme/nonce_test.go new file mode 100644 index 00000000..6aa467a0 --- /dev/null +++ b/acme/nonce_test.go @@ -0,0 +1,163 @@ +package acme + +import ( + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +func TestNewNonce(t *testing.T) { + type test struct { + db nosql.DB + err *Error + id *string + } + tests := map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error storing nonce: force")), + } + }, + "fail/cmpAndSwap-false": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + return nil, false, nil + }, + }, + err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")), + } + }, + "ok": func(t *testing.T) test { + var _id string + id := &_id + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + *id = string(key) + return nil, true, nil + }, + }, + id: id, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if n, err := newNonce(tc.db); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, n.ID, *tc.id) + + assert.True(t, n.Created.Before(time.Now().Add(time.Minute))) + assert.True(t, n.Created.After(time.Now().Add(-time.Minute))) + } + } + }) + } +} + +func TestUseNonce(t *testing.T) { + type test struct { + id string + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/update-not-found": func(t *testing.T) test { + id := "foo" + return test{ + db: &db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(id)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(id)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return database.ErrNotFound + }, + }, + id: id, + err: BadNonceErr(nil), + } + }, + "fail/update-error": func(t *testing.T) test { + id := "foo" + return test{ + db: &db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(id)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(id)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return errors.New("force") + }, + }, + id: id, + err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)), + } + }, + "ok": func(t *testing.T) test { + id := "foo" + return test{ + db: &db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(id)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(id)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + + return nil + }, + }, + id: id, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := useNonce(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } + }) + } +} diff --git a/acme/order.go b/acme/order.go new file mode 100644 index 00000000..ab4ac659 --- /dev/null +++ b/acme/order.go @@ -0,0 +1,342 @@ +package acme + +import ( + "context" + "crypto/x509" + "encoding/json" + "reflect" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/nosql" +) + +var defaultOrderExpiry = time.Hour * 24 + +// Order contains order metadata for the ACME protocol order type. +type Order struct { + Status string `json:"status"` + Expires string `json:"expires,omitempty"` + Identifiers []Identifier `json:"identifiers"` + NotBefore string `json:"notBefore,omitempty"` + NotAfter string `json:"notAfter,omitempty"` + Error interface{} `json:"error,omitempty"` + Authorizations []string `json:"authorizations"` + Finalize string `json:"finalize"` + Certificate string `json:"certificate,omitempty"` + ID string `json:"-"` +} + +// ToLog enables response logging. +func (o *Order) ToLog() (interface{}, error) { + b, err := json.Marshal(o) + if err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging")) + } + return string(b), nil +} + +// GetID returns the Order ID. +func (o *Order) GetID() string { + return o.ID +} + +// OrderOptions options with which to create a new Order. +type OrderOptions struct { + AccountID string `json:"accID"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` +} + +type order struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires,omitempty"` + Status string `json:"status"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + Error *Error `json:"error,omitempty"` + Authorizations []string `json:"authorizations"` + Certificate string `json:"certificate,omitempty"` +} + +// newOrder returns a new Order type. +func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { + id, err := randID() + if err != nil { + return nil, err + } + + authzs := make([]string, len(ops.Identifiers)) + for i, identifier := range ops.Identifiers { + authz, err := newAuthz(db, ops.AccountID, identifier) + if err != nil { + return nil, err + } + authzs[i] = authz.getID() + } + + now := clock.Now() + o := &order{ + ID: id, + AccountID: ops.AccountID, + Created: now, + Status: StatusPending, + Expires: now.Add(defaultOrderExpiry), + Identifiers: ops.Identifiers, + NotBefore: ops.NotBefore, + NotAfter: ops.NotAfter, + Authorizations: authzs, + } + if err := o.save(db, nil); err != nil { + return nil, err + } + + // Update the "order IDs by account ID" index // + oids, err := getOrderIDsByAccount(db, ops.AccountID) + if err != nil { + return nil, err + } + newOids := append(oids, o.ID) + if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil { + db.Del(orderTable, []byte(o.ID)) + return nil, err + } + return o, nil +} + +type orderIDs []string + +func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { + var ( + err error + oldb []byte + ) + if len(old) == 0 { + oldb = nil + } else { + oldb, err = json.Marshal(old) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) + } + } + newb, err := json.Marshal(oids) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) + } + _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) + switch { + case err != nil: + return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) + case !swapped: + return ServerInternalErr(errors.Errorf("error storing order IDs "+ + "for account %s; order IDs changed since last read", accID)) + default: + return nil + } +} + +func (o *order) save(db nosql.DB, old *order) error { + var ( + err error + oldB []byte + ) + if old == nil { + oldB = nil + } else { + if oldB, err = json.Marshal(old); err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) + } + } + + newB, err := json.Marshal(o) + if err != nil { + return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order")) + } + + _, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB) + switch { + case err != nil: + return ServerInternalErr(errors.Wrap(err, "error storing order")) + case !swapped: + return ServerInternalErr(errors.New("error storing order; " + + "value has changed since last read")) + default: + return nil + } +} + +// updateStatus updates order status if necessary. +func (o *order) updateStatus(db nosql.DB) (*order, error) { + _newOrder := *o + newOrder := &_newOrder + + now := time.Now().UTC() + switch o.Status { + case StatusInvalid: + return o, nil + case StatusValid: + return o, nil + case StatusReady: + // check expiry + if now.After(o.Expires) { + newOrder.Status = StatusInvalid + newOrder.Error = MalformedErr(errors.New("order has expired")) + break + } + return o, nil + case StatusPending: + // check expiry + if now.After(o.Expires) { + newOrder.Status = StatusInvalid + newOrder.Error = MalformedErr(errors.New("order has expired")) + break + } + + var count = map[string]int{ + StatusValid: 0, + StatusInvalid: 0, + StatusPending: 0, + } + for _, azID := range o.Authorizations { + authz, err := getAuthz(db, azID) + if err != nil { + return nil, err + } + if authz, err = authz.updateStatus(db); err != nil { + return nil, err + } + st := authz.getStatus() + count[st]++ + } + switch { + case count[StatusInvalid] > 0: + newOrder.Status = StatusInvalid + case count[StatusPending] > 0: + break + case count[StatusValid] == len(o.Authorizations): + newOrder.Status = StatusReady + default: + return nil, ServerInternalErr(errors.New("unexpected authz status")) + } + default: + return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) + } + + if err := newOrder.save(db, o); err != nil { + return nil, err + } + return newOrder, nil +} + +// finalize signs a certificate if the necessary conditions for Order completion +// have been met. +func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) { + var err error + if o, err = o.updateStatus(db); err != nil { + return nil, err + } + switch o.Status { + case StatusInvalid: + return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) + case StatusValid: + return o, nil + case StatusPending: + return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) + case StatusReady: + break + default: + return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) + } + + // Validate identifier names against CSR alternative names // + csrNames := make(map[string]int) + for _, n := range csr.DNSNames { + csrNames[n] = 1 + } + orderNames := make(map[string]int) + for _, n := range o.Identifiers { + orderNames[n.Value] = 1 + } + if !reflect.DeepEqual(csrNames, orderNames) { + return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")) + } + + // Get authorizations from the ACME provisioner. + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + signOps, err := p.AuthorizeSign(ctx, "") + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) + } + + // Create and store a new certificate. + leaf, inter, err := auth.Sign(csr, provisioner.Options{ + NotBefore: provisioner.NewTimeDuration(o.NotBefore), + NotAfter: provisioner.NewTimeDuration(o.NotAfter), + }, signOps...) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) + } + + cert, err := newCert(db, CertOptions{ + AccountID: o.AccountID, + OrderID: o.ID, + Leaf: leaf, + Intermediates: []*x509.Certificate{inter}, + }) + if err != nil { + return nil, err + } + + _newOrder := *o + newOrder := &_newOrder + newOrder.Certificate = cert.ID + newOrder.Status = StatusValid + if err := newOrder.save(db, o); err != nil { + return nil, err + } + return newOrder, nil +} + +// getOrder retrieves and unmarshals an ACME Order type from the database. +func getOrder(db nosql.DB, id string) (*order, error) { + b, err := db.Get(orderTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id)) + } else if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id)) + } + var o order + if err := json.Unmarshal(b, &o); err != nil { + return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order")) + } + return &o, nil +} + +// toACME converts the internal Order type into the public acmeOrder type for +// presentation in the ACME protocol. +func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) { + azs := make([]string, len(o.Authorizations)) + for i, aid := range o.Authorizations { + azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid) + } + ao := &Order{ + Status: o.Status, + Expires: o.Expires.Format(time.RFC3339), + Identifiers: o.Identifiers, + NotBefore: o.NotBefore.Format(time.RFC3339), + NotAfter: o.NotAfter.Format(time.RFC3339), + Authorizations: azs, + Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID), + ID: o.ID, + } + + if o.Certificate != "" { + ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate) + } + return ao, nil +} diff --git a/acme/order_test.go b/acme/order_test.go new file mode 100644 index 00000000..31601fae --- /dev/null +++ b/acme/order_test.go @@ -0,0 +1,1129 @@ +package acme + +import ( + "context" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +var certDuration = 6 * time.Hour + +func defaultOrderOps() OrderOptions { + return OrderOptions{ + AccountID: "accID", + Identifiers: []Identifier{ + {Type: "dns", Value: "acme.example.com"}, + {Type: "dns", Value: "step.example.com"}, + }, + NotBefore: clock.Now(), + NotAfter: clock.Now().Add(certDuration), + } +} + +func newO() (*order, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + MGet: func(bucket, key []byte) ([]byte, error) { + b, err := json.Marshal([]string{"1", "2"}) + if err != nil { + return nil, err + } + return b, nil + }, + } + return newOrder(mockdb, defaultOrderOps()) +} + +func TestGetOrder(t *testing.T) { + type test struct { + id string + db nosql.DB + o *order + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{ + o: o, + id: o.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("order %s not found: not found", o.ID)), + } + }, + "fail/db-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{ + o: o, + id: o.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error loading order %s: force", o.ID)), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{ + o: o, + id: o.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return nil, nil + }, + }, + err: ServerInternalErr(errors.New("error unmarshaling order: unexpected end of JSON input")), + } + }, + "ok": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + return test{ + o: o, + id: o.ID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(o.ID)) + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if o, err := getOrder(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.o.ID, o.ID) + assert.Equals(t, tc.o.AccountID, o.AccountID) + assert.Equals(t, tc.o.Status, o.Status) + assert.Equals(t, tc.o.Identifiers, o.Identifiers) + assert.Equals(t, tc.o.Created, o.Created) + assert.Equals(t, tc.o.Expires, o.Expires) + assert.Equals(t, tc.o.Authorizations, o.Authorizations) + assert.Equals(t, tc.o.NotBefore, o.NotBefore) + assert.Equals(t, tc.o.NotAfter, o.NotAfter) + assert.Equals(t, tc.o.Certificate, o.Certificate) + assert.Equals(t, tc.o.Error, o.Error) + } + } + }) + } +} + +func TestOrderToACME(t *testing.T) { + dir := newDirectory("ca.smallstep.com", "acme") + prov := newProv() + + type test struct { + o *order + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/no-cert": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{o: o} + }, + "ok/cert": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusValid + o.Certificate = "cert-id" + return test{o: o} + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + acmeOrder, err := tc.o.toACME(nil, dir, prov) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acmeOrder.ID, tc.o.ID) + assert.Equals(t, acmeOrder.Status, tc.o.Status) + assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers) + assert.Equals(t, acmeOrder.Finalize, fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s/finalize", URLSafeProvisionerName(prov), tc.o.ID)) + if tc.o.Certificate != "" { + assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/%s", URLSafeProvisionerName(prov), tc.o.Certificate)) + } + + expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires) + assert.FatalError(t, err) + assert.Equals(t, expiry.String(), tc.o.Expires.String()) + nbf, err := time.Parse(time.RFC3339, acmeOrder.NotBefore) + assert.FatalError(t, err) + assert.Equals(t, nbf.String(), tc.o.NotBefore.String()) + naf, err := time.Parse(time.RFC3339, acmeOrder.NotAfter) + assert.FatalError(t, err) + assert.Equals(t, naf.String(), tc.o.NotAfter.String()) + } + } + }) + } +} + +func TestOrderSave(t *testing.T) { + type test struct { + o, old *order + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{ + o: o, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing order: force")), + } + }, + "fail/old-nil/swap-false": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{ + o: o, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil + }, + }, + err: ServerInternalErr(errors.New("error storing order; value has changed since last read")), + } + }, + "ok/old-nil": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + return test{ + o: o, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, orderTable) + assert.Equals(t, []byte(o.ID), key) + return nil, true, nil + }, + }, + } + }, + "ok/old-not-nil": func(t *testing.T) test { + oldo, err := newO() + assert.FatalError(t, err) + o, err := newO() + assert.FatalError(t, err) + + oldb, err := json.Marshal(oldo) + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + return test{ + o: o, + old: oldo, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, oldb) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, orderTable) + assert.Equals(t, []byte(o.ID), key) + return []byte("foo"), true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.o.save(tc.db, tc.old); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestNewOrder(t *testing.T) { + type test struct { + ops OrderOptions + db nosql.DB + err *Error + authzs *([]string) + } + tests := map[string]func(t *testing.T) test{ + "fail/unexpected-identifier-type": func(t *testing.T) test { + ops := defaultOrderOps() + ops.Identifiers[0].Type = "foo" + return test{ + ops: ops, + err: MalformedErr(errors.New("unexpected authz type foo")), + } + }, + "fail/save-order-error": func(t *testing.T) test { + count := 0 + return test{ + ops: defaultOrderOps(), + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count >= 6 { + return nil, false, errors.New("force") + } + count++ + return nil, true, nil + }, + }, + err: ServerInternalErr(errors.New("error storing order: force")), + } + }, + "fail/get-orderIDs-error": func(t *testing.T) test { + count := 0 + ops := defaultOrderOps() + return test{ + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count >= 7 { + return nil, false, errors.New("force") + } + count++ + return nil, true, nil + }, + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error loading orderIDs for account %s: force", ops.AccountID)), + } + }, + "fail/save-orderIDs-error": func(t *testing.T) test { + count := 0 + oids := []string{"1", "2"} + oidsB, err := json.Marshal(oids) + assert.FatalError(t, err) + var ( + _oid = "" + oid = &_oid + ) + ops := defaultOrderOps() + return test{ + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count >= 7 { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(ops.AccountID)) + return nil, false, errors.New("force") + } else if count == 6 { + *oid = string(key) + } + count++ + return nil, true, nil + }, + MGet: func(bucket, key []byte) ([]byte, error) { + return oidsB, nil + }, + MDel: func(bucket, key []byte) error { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte(*oid)) + return nil + }, + }, + err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", ops.AccountID)), + } + }, + "ok": func(t *testing.T) test { + count := 0 + oids := []string{"1", "2"} + oidsB, err := json.Marshal(oids) + assert.FatalError(t, err) + authzs := &([]string{}) + var ( + _oid = "" + oid = &_oid + ) + ops := defaultOrderOps() + return test{ + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count >= 7 { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(ops.AccountID)) + assert.Equals(t, old, oidsB) + newB, err := json.Marshal(append(oids, *oid)) + assert.FatalError(t, err) + assert.Equals(t, newval, newB) + } else if count == 6 { + *oid = string(key) + } else if count == 5 { + *authzs = append(*authzs, string(key)) + } else if count == 2 { + *authzs = []string{string(key)} + } + count++ + return nil, true, nil + }, + MGet: func(bucket, key []byte) ([]byte, error) { + return oidsB, nil + }, + }, + authzs: authzs, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + o, err := newOrder(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, o.AccountID, tc.ops.AccountID) + assert.Equals(t, o.Status, StatusPending) + assert.Equals(t, o.Identifiers, tc.ops.Identifiers) + assert.Equals(t, o.Error, nil) + assert.Equals(t, o.Certificate, "") + assert.Equals(t, o.Authorizations, *tc.authzs) + + assert.True(t, o.Created.Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, o.Created.After(time.Now().UTC().Add(-1*time.Minute))) + + expiry := o.Created.Add(defaultExpiryDuration) + assert.True(t, o.Expires.Before(expiry.Add(time.Minute))) + assert.True(t, o.Expires.After(expiry.Add(-1*time.Minute))) + + assert.Equals(t, o.NotBefore, tc.ops.NotBefore) + assert.Equals(t, o.NotAfter, tc.ops.NotAfter) + } + } + }) + } +} + +func TestOrderIDsSave(t *testing.T) { + accID := "acc-id" + newOids := func() orderIDs { + return []string{"1", "2"} + } + type test struct { + oids, old orderIDs + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + return test{ + oids: newOids(), + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", accID)), + } + }, + "fail/old-nil/swap-false": func(t *testing.T) test { + return test{ + oids: newOids(), + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil + }, + }, + err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s; order IDs changed since last read", accID)), + } + }, + "ok/old-nil": func(t *testing.T) test { + oids := newOids() + b, err := json.Marshal(oids) + assert.FatalError(t, err) + return test{ + oids: oids, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, true, nil + }, + }, + } + }, + "ok/old-not-nil": func(t *testing.T) test { + oldOids := newOids() + oids := append(oldOids, "3") + + oldb, err := json.Marshal(oldOids) + assert.FatalError(t, err) + b, err := json.Marshal(oids) + assert.FatalError(t, err) + return test{ + oids: oids, + old: oldOids, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, oldb) + assert.Equals(t, newval, b) + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.oids.save(tc.db, tc.old, accID); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestOrderUpdateStatus(t *testing.T) { + type test struct { + o, res *order + err *Error + db nosql.DB + } + tests := map[string]func(t *testing.T) test{ + "fail/already-invalid": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusInvalid + return test{ + o: o, + res: o, + } + }, + "fail/already-valid": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusValid + return test{ + o: o, + res: o, + } + }, + "fail/unexpected-status": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusDeactivated + return test{ + o: o, + res: o, + err: ServerInternalErr(errors.New("unrecognized order status: deactivated")), + } + }, + "fail/save-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Expires = time.Now().UTC().Add(-time.Minute) + return test{ + o: o, + res: o, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing order: force")), + } + }, + "ok/expired": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Expires = time.Now().UTC().Add(-time.Minute) + + _o := *o + clone := &_o + clone.Error = MalformedErr(errors.New("order has expired")) + clone.Status = StatusInvalid + return test{ + o: o, + res: clone, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + }, + } + }, + "fail/get-authz-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + return test{ + o: o, + res: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading authz")), + } + }, + "ok/still-pending": func(t *testing.T) test { + az1, err := newAz() + assert.FatalError(t, err) + az2, err := newAz() + assert.FatalError(t, err) + + ch1, err := newHTTPCh() + assert.FatalError(t, err) + ch2, err := newDNSCh() + assert.FatalError(t, err) + + ch1b, err := json.Marshal(ch1) + assert.FatalError(t, err) + ch2b, err := json.Marshal(ch2) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + o.Authorizations = []string{az1.getID(), az2.getID()} + + _az2, ok := az2.(*dnsAuthz) + assert.Fatal(t, ok) + _az2.baseAuthz.Status = StatusValid + + b1, err := json.Marshal(az1) + assert.FatalError(t, err) + b2, err := json.Marshal(az2) + assert.FatalError(t, err) + + count := 0 + return test{ + o: o, + res: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + ret = b1 + case 1: + ret = ch1b + case 2: + ret = ch2b + case 3: + ret = b2 + default: + return nil, errors.New("unexpected count") + } + count++ + return ret, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + }, + } + }, + "ok/invalid": func(t *testing.T) test { + az1, err := newAz() + assert.FatalError(t, err) + az2, err := newAz() + assert.FatalError(t, err) + + ch1, err := newHTTPCh() + assert.FatalError(t, err) + ch2, err := newDNSCh() + assert.FatalError(t, err) + + ch1b, err := json.Marshal(ch1) + assert.FatalError(t, err) + ch2b, err := json.Marshal(ch2) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + o.Authorizations = []string{az1.getID(), az2.getID()} + + _az2, ok := az2.(*dnsAuthz) + assert.Fatal(t, ok) + _az2.baseAuthz.Status = StatusInvalid + + b1, err := json.Marshal(az1) + assert.FatalError(t, err) + b2, err := json.Marshal(az2) + assert.FatalError(t, err) + + _o := *o + clone := &_o + clone.Status = StatusInvalid + + count := 0 + return test{ + o: o, + res: clone, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + ret = b1 + case 1: + ret = ch1b + case 2: + ret = ch2b + case 3: + ret = b2 + default: + return nil, errors.New("unexpected count") + } + count++ + return ret, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + o, err := tc.o.updateStatus(tc.db) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + expB, err := json.Marshal(tc.res) + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + assert.Equals(t, expB, b) + } + } + }) + } +} + +type mockSignAuth struct { + sign func(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) + loadProvisionerByID func(string) (provisioner.Interface, error) + ret1, ret2 interface{} + err error +} + +func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) { + if m.sign != nil { + return m.sign(csr, signOpts, extraOpts...) + } else if m.err != nil { + return nil, nil, m.err + } + return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err +} + +func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, error) { + if m.loadProvisionerByID != nil { + return m.loadProvisionerByID(id) + } + return m.ret1.(provisioner.Interface), m.err +} + +func TestOrderFinalize(t *testing.T) { + prov := newProv() + type test struct { + o, res *order + err *Error + db nosql.DB + csr *x509.CertificateRequest + sa SignAuthority + prov provisioner.Interface + } + tests := map[string]func(t *testing.T) test{ + "fail/already-invalid": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusInvalid + return test{ + o: o, + err: OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)), + } + }, + "ok/already-valid": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusValid + o.Certificate = "cert-id" + return test{ + o: o, + res: o, + } + }, + "fail/still-pending": func(t *testing.T) test { + az1, err := newAz() + assert.FatalError(t, err) + az2, err := newAz() + assert.FatalError(t, err) + + ch1, err := newHTTPCh() + assert.FatalError(t, err) + ch2, err := newDNSCh() + assert.FatalError(t, err) + + ch1b, err := json.Marshal(ch1) + assert.FatalError(t, err) + ch2b, err := json.Marshal(ch2) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + o.Authorizations = []string{az1.getID(), az2.getID()} + + _az2, ok := az2.(*dnsAuthz) + assert.Fatal(t, ok) + _az2.baseAuthz.Status = StatusValid + + b1, err := json.Marshal(az1) + assert.FatalError(t, err) + b2, err := json.Marshal(az2) + assert.FatalError(t, err) + + count := 0 + return test{ + o: o, + res: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + var ret []byte + switch count { + case 0: + ret = b1 + case 1: + ret = ch1b + case 2: + ret = ch2b + case 3: + ret = b2 + default: + return nil, errors.New("unexpected count") + } + count++ + return ret, nil + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, true, nil + }, + }, + err: OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)), + } + }, + "fail/ready/csr-names-match-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo", + }, + DNSNames: []string{"bar", "baz"}, + } + return test{ + o: o, + csr: csr, + err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + } + }, + "fail/ready/csr-names-match-error-2": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "acme.example.com", + }, + DNSNames: []string{"step.example.com"}, + } + return test{ + o: o, + csr: csr, + err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + } + }, + "fail/ready/provisioner-auth-sign-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo", + }, + DNSNames: []string{"step.example.com", "acme.example.com"}, + } + return test{ + o: o, + csr: csr, + err: ServerInternalErr(errors.New("error retrieving authorization options from ACME provisioner: force")), + prov: &provisioner.MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + return nil, errors.New("force") + }, + }, + } + }, + "fail/ready/sign-cert-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo", + }, + DNSNames: []string{"step.example.com", "acme.example.com"}, + } + return test{ + o: o, + csr: csr, + err: ServerInternalErr(errors.Errorf("error generating certificate for order %s: force", o.ID)), + sa: &mockSignAuth{ + err: errors.New("force"), + }, + } + }, + "fail/ready/store-cert-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo", + }, + DNSNames: []string{"step.example.com", "acme.example.com"}, + } + crt := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "acme.example.com", + }, + } + inter := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "intermediate", + }, + } + return test{ + o: o, + csr: csr, + err: ServerInternalErr(errors.Errorf("error storing certificate: force")), + sa: &mockSignAuth{ + ret1: crt, ret2: inter, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + } + }, + "fail/ready/store-order-error": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "acme", + }, + DNSNames: []string{"acme.example.com", "step.example.com"}, + } + crt := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "acme.example.com", + }, + } + inter := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "intermediate", + }, + } + count := 0 + return test{ + o: o, + csr: csr, + err: ServerInternalErr(errors.Errorf("error storing order: force")), + sa: &mockSignAuth{ + ret1: crt, ret2: inter, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 1 { + return nil, false, errors.New("force") + } + count++ + return nil, true, nil + }, + }, + } + }, + "ok/ready/sign": func(t *testing.T) test { + o, err := newO() + assert.FatalError(t, err) + o.Status = StatusReady + + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo", + }, + DNSNames: []string{"acme.example.com", "step.example.com"}, + } + crt := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "acme.example.com", + }, + } + inter := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "intermediate", + }, + } + + _o := *o + clone := &_o + clone.Status = StatusValid + + count := 0 + return test{ + o: o, + res: clone, + csr: csr, + sa: &mockSignAuth{ + sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) { + assert.Equals(t, len(signOps), 4) + return crt, inter, nil + }, + }, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count == 0 { + clone.Certificate = string(key) + } + count++ + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + p := tc.prov + if p == nil { + p = prov + } + o, err := tc.o.finalize(tc.db, tc.csr, tc.sa, p) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + expB, err := json.Marshal(tc.res) + assert.FatalError(t, err) + b, err := json.Marshal(o) + assert.FatalError(t, err) + assert.Equals(t, expB, b) + } + } + }) + } +} diff --git a/api/api.go b/api/api.go index fd091c86..3850d921 100644 --- a/api/api.go +++ b/api/api.go @@ -28,8 +28,7 @@ import ( // Authority is the interface implemented by a CA authority. type Authority interface { SSHAuthority - // NOTE: Authorize will be deprecated in future releases. Please use the - // context specific Authorize[Sign|Revoke|etc.] methods. + // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) GetTLSOptions() *tlsutil.TLSOptions @@ -37,6 +36,7 @@ type Authority interface { Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error) + LoadProvisionerByID(string) (provisioner.Interface, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) Revoke(*authority.RevokeOptions) error GetEncryptedKey(kid string) (string, error) @@ -308,13 +308,12 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { return } - w.WriteHeader(http.StatusCreated) logCertificate(w, cert) - JSON(w, &SignResponse{ + JSONStatus(w, &SignResponse{ ServerPEM: Certificate{cert}, CaPEM: Certificate{root}, TLSOptions: h.Authority.GetTLSOptions(), - }) + }, http.StatusCreated) } // Renew uses the information of certificate in the TLS connection to create a @@ -331,13 +330,12 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { return } - w.WriteHeader(http.StatusCreated) logCertificate(w, cert) - JSON(w, &SignResponse{ + JSONStatus(w, &SignResponse{ ServerPEM: Certificate{cert}, CaPEM: Certificate{root}, TLSOptions: h.Authority.GetTLSOptions(), - }) + }, http.StatusCreated) } // Provisioners returns the list of provisioners configured in the authority. @@ -383,10 +381,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{roots[i]} } - w.WriteHeader(http.StatusCreated) - JSON(w, &RootsResponse{ + JSONStatus(w, &RootsResponse{ Certificates: certs, - }) + }, http.StatusCreated) } // Federation returns all the public certificates in the federation. @@ -402,10 +399,9 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{federated[i]} } - w.WriteHeader(http.StatusCreated) - JSON(w, &FederationResponse{ + JSONStatus(w, &FederationResponse{ Certificates: certs, - }) + }, http.StatusCreated) } var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1} diff --git a/api/api_test.go b/api/api_test.go index 5ece5cc9..d141247c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -506,6 +506,7 @@ type mockAuthority struct { signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) + loadProvisionerByID func(provID string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) revoke func(*authority.RevokeOptions) error getEncryptedKey func(kid string) (string, error) @@ -581,6 +582,13 @@ func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (pr return m.ret1.(provisioner.Interface), m.err } +func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { + if m.loadProvisionerByID != nil { + return m.loadProvisionerByID(provID) + } + return m.ret1.(provisioner.Interface), m.err +} + func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error { if m.revoke != nil { return m.revoke(opts) diff --git a/api/errors.go b/api/errors.go index 0e6cb939..90b41565 100644 --- a/api/errors.go +++ b/api/errors.go @@ -7,6 +7,7 @@ import ( "os" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/logging" ) @@ -109,7 +110,13 @@ func NotFound(err error) error { // WriteError writes to w a JSON representation of the given error. func WriteError(w http.ResponseWriter, err error) { - w.Header().Set("Content-Type", "application/json") + switch k := err.(type) { + case *acme.Error: + w.Header().Set("Content-Type", "application/problem+json") + err = k.ToACME() + default: + w.Header().Set("Content-Type", "application/json") + } cause := errors.Cause(err) if sc, ok := err.(StatusCoder); ok { w.WriteHeader(sc.StatusCode()) diff --git a/api/revoke.go b/api/revoke.go index 9280c980..15c42e90 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -87,8 +87,6 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } logRevoke(w, opts) - - w.WriteHeader(http.StatusOK) JSON(w, &RevokeResponse{Status: "ok"}) } diff --git a/api/utils.go b/api/utils.go index a685a673..89adedb7 100644 --- a/api/utils.go +++ b/api/utils.go @@ -10,6 +10,11 @@ import ( "github.com/smallstep/certificates/logging" ) +// EnableLogger is an interface that enables response logging for an object. +type EnableLogger interface { + ToLog() (interface{}, error) +} + // LogError adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. @@ -23,12 +28,40 @@ func LogError(rw http.ResponseWriter, err error) { } } +// LogEnabledResponse log the response object if it implements the EnableLogger +// interface. +func LogEnabledResponse(rw http.ResponseWriter, v interface{}) { + if el, ok := v.(EnableLogger); ok { + out, err := el.ToLog() + if err != nil { + LogError(rw, err) + return + } + if rl, ok := rw.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "response": out, + }) + } else { + log.Println(out) + } + } +} + // JSON writes the passed value into the http.ResponseWriter. func JSON(w http.ResponseWriter, v interface{}) { + JSONStatus(w, v, http.StatusOK) +} + +// JSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func JSONStatus(w http.ResponseWriter, v interface{}, status int) { w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) if err := json.NewEncoder(w).Encode(v); err != nil { LogError(w, err) + return } + LogEnabledResponse(w, v) } // ReadJSON reads JSON from the request body and stores it in the value diff --git a/authority/authority.go b/authority/authority.go index 848a4f63..399738d6 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -15,7 +15,9 @@ import ( "github.com/smallstep/cli/crypto/x509util" ) -const legacyAuthority = "step-certificate-authority" +const ( + legacyAuthority = "step-certificate-authority" +) // Authority implements the Certificate Authority internal interface. type Authority struct { diff --git a/authority/error.go b/authority/error.go index 85293f20..de1aa3c0 100644 --- a/authority/error.go +++ b/authority/error.go @@ -1,6 +1,8 @@ package authority import ( + "encoding/json" + "fmt" "net/http" ) @@ -33,6 +35,12 @@ func (e *apiError) Error() string { return ret } +// ErrorResponse represents an error in JSON format. +type ErrorResponse struct { + Status int `json:"status"` + Message string `json:"message"` +} + // StatusCode returns an http status code indicating the type and severity of // the error. func (e *apiError) StatusCode() int { @@ -41,3 +49,19 @@ func (e *apiError) StatusCode() int { } return e.code } + +// MarshalJSON implements json.Marshaller interface for the Error struct. +func (e *apiError) MarshalJSON() ([]byte, error) { + return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)}) +} + +// UnmarshalJSON implements json.Unmarshaler interface for the Error struct. +func (e *apiError) UnmarshalJSON(data []byte) error { + var er ErrorResponse + if err := json.Unmarshal(data, &er); err != nil { + return err + } + e.code = er.Status + e.err = fmt.Errorf(er.Message) + return nil +} diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go new file mode 100644 index 00000000..2d78e5a2 --- /dev/null +++ b/authority/provisioner/acme.go @@ -0,0 +1,85 @@ +package provisioner + +import ( + "context" + "crypto/x509" + + "github.com/pkg/errors" +) + +// ACME is the acme provisioner type, an entity that can authorize the ACME +// provisioning flow. +type ACME struct { + Type string `json:"type"` + Name string `json:"name"` + Claims *Claims `json:"claims,omitempty"` + claimer *Claimer +} + +// GetID returns the provisioner unique identifier. +func (p ACME) GetID() string { + return "acme/" + p.Name +} + +// GetTokenID returns the identifier of the token. +func (p *ACME) GetTokenID(ott string) (string, error) { + return "", errors.New("acme provisioner does not implement GetTokenID") +} + +// GetName returns the name of the provisioner. +func (p *ACME) GetName() string { + return p.Name +} + +// GetType returns the type of provisioner. +func (p *ACME) GetType() Type { + return TypeACME +} + +// GetEncryptedKey returns the base provisioner encrypted key if it's defined. +func (p *ACME) GetEncryptedKey() (string, string, bool) { + return "", "", false +} + +// Init initializes and validates the fields of a JWK type. +func (p *ACME) Init(config Config) (err error) { + switch { + case p.Type == "": + return errors.New("provisioner type cannot be empty") + case p.Name == "": + return errors.New("provisioner name cannot be empty") + } + + // Update claims with global ones + if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { + return err + } + + return err +} + +// AuthorizeRevoke is not implemented yet for the ACME provisioner. +func (p *ACME) AuthorizeRevoke(token string) error { + return nil +} + +// AuthorizeSign validates the given token. +func (p *ACME) AuthorizeSign(ctx context.Context, _ string) ([]SignOption, error) { + if m := MethodFromContext(ctx); m != SignMethod { + return nil, errors.Errorf("unexpected method type %d in context", m) + } + return []SignOption{ + profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + newProvisionerExtensionOption(TypeACME, p.Name, ""), + newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + defaultPublicKeyValidator{}, + }, nil +} + +// AuthorizeRenewal is not implemented for the ACME provisioner. +func (p *ACME) AuthorizeRenewal(cert *x509.Certificate) error { + if p.claimer.IsDisableRenewal() { + return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + } + return nil +} diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go new file mode 100644 index 00000000..c5e6b13a --- /dev/null +++ b/authority/provisioner/acme_test.go @@ -0,0 +1,184 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" +) + +func TestACME_Getters(t *testing.T) { + p, err := generateACME() + assert.FatalError(t, err) + id := "acme/" + p.Name + if got := p.GetID(); got != id { + t.Errorf("ACME.GetID() = %v, want %v", got, id) + } + if got := p.GetName(); got != p.Name { + t.Errorf("ACME.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeACME { + t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME) + } + kid, key, ok := p.GetEncryptedKey() + if kid != "" || key != "" || ok == true { + t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, "", "", false) + } +} + +func TestACME_Init(t *testing.T) { + type ProvisionerValidateTest struct { + p *ACME + err error + } + tests := map[string]func(*testing.T) ProvisionerValidateTest{ + "fail-empty": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{}, + err: errors.New("provisioner type cannot be empty"), + } + }, + "fail-empty-name": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{ + Type: "ACME", + }, + err: errors.New("provisioner name cannot be empty"), + } + }, + "fail-empty-type": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{Name: "foo"}, + err: errors.New("provisioner type cannot be empty"), + } + }, + "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}}, + err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"), + } + }, + "ok": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{Name: "foo", Type: "bar"}, + } + }, + } + + config := Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + } + for name, get := range tests { + t.Run(name, func(t *testing.T) { + tc := get(t) + err := tc.p.Init(config) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.Equals(t, tc.err.Error(), err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestACME_AuthorizeRevoke(t *testing.T) { + p, err := generateACME() + assert.FatalError(t, err) + assert.Nil(t, p.AuthorizeRevoke("")) +} + +func TestACME_AuthorizeRenewal(t *testing.T) { + p1, err := generateACME() + assert.FatalError(t, err) + p2, err := generateACME() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{DisableRenewal: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + prov *ACME + args args + err error + }{ + {"ok", p1, args{nil}, nil}, + {"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRenewal(tt.args.cert); err != nil { + if assert.NotNil(t, tt.err) { + assert.HasPrefix(t, err.Error(), tt.err.Error()) + } + } else { + assert.Nil(t, tt.err) + } + }) + } +} + +func TestACME_AuthorizeSign(t *testing.T) { + p1, err := generateACME() + assert.FatalError(t, err) + + tests := []struct { + name string + prov *ACME + method Method + err error + }{ + {"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")}, + {"ok", p1, SignMethod, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), tt.method) + if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil { + if assert.NotNil(t, tt.err) { + assert.HasPrefix(t, err.Error(), tt.err.Error()) + } + } else { + if assert.NotNil(t, got) { + assert.Len(t, 4, got) + + _pdd := got[0] + pdd, ok := _pdd.(profileDefaultDuration) + assert.True(t, ok) + assert.Equals(t, pdd, profileDefaultDuration(86400000000000)) + + _peo := got[1] + peo, ok := _peo.(*provisionerExtensionOption) + assert.True(t, ok) + assert.Equals(t, peo.Type, 6) + assert.Equals(t, peo.Name, "test@acme-provisioner.com") + assert.Equals(t, peo.CredentialID, "") + assert.Equals(t, peo.KeyValuePairs, nil) + + _vv := got[2] + vv, ok := _vv.(*validityValidator) + assert.True(t, ok) + assert.Equals(t, vv.min, time.Duration(300000000000)) + assert.Equals(t, vv.max, time.Duration(86400000000000)) + + _dpkv := got[3] + _, ok = _dpkv.(defaultPublicKeyValidator) + assert.True(t, ok) + } + } + }) + } +} diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index ca7a5391..f674bae8 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -127,6 +127,8 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) return c.Load("aws/" + string(provisioner.Name)) case TypeGCP: return c.Load("gcp/" + string(provisioner.Name)) + case TypeACME: + return c.Load("acme/" + string(provisioner.Name)) default: return c.Load(string(provisioner.CredentialID)) } @@ -152,8 +154,9 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) { // Store adds a provisioner to the collection and enforces the uniqueness of // provisioner IDs. func (c *Collection) Store(p Interface) error { + fmt.Printf("p.GetID() = %+v\n", p.GetID()) // Store provisioner always in byID. ID must be unique. - if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true { + if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded { return errors.New("cannot add multiple provisioners with the same id") } diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index 55e0f056..b06a27c5 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -133,15 +133,20 @@ func TestCollection_LoadByCertificate(t *testing.T) { assert.FatalError(t, err) p2, err := generateOIDC() assert.FatalError(t, err) + p3, err := generateACME() + assert.FatalError(t, err) byID := new(sync.Map) byID.Store(p1.GetID(), p1) byID.Store(p2.GetID(), p2) + byID.Store(p3.GetID(), p3) ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) assert.FatalError(t, err) ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) assert.FatalError(t, err) + ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "") + assert.FatalError(t, err) notFoundExt, err := createProvisionerExtension(1, "foo", "bar") assert.FatalError(t, err) @@ -151,6 +156,9 @@ func TestCollection_LoadByCertificate(t *testing.T) { ok2Cert := &x509.Certificate{ Extensions: []pkix.Extension{ok2Ext}, } + ok3Cert := &x509.Certificate{ + Extensions: []pkix.Extension{ok3Ext}, + } notFoundCert := &x509.Certificate{ Extensions: []pkix.Extension{notFoundExt}, } @@ -176,6 +184,7 @@ func TestCollection_LoadByCertificate(t *testing.T) { }{ {"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true}, {"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true}, + {"ok3", fields{byID, testAudiences}, args{ok3Cert}, p3, true}, {"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, {"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false}, {"badCert", fields{byID, testAudiences}, args{badCert}, nil, false}, diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 248b93cf..2a63161c 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -84,6 +84,8 @@ const ( TypeAWS Type = 4 // TypeAzure is used to indicate the Azure provisioners. TypeAzure Type = 5 + // TypeACME is used to indicate the ACME provisioners. + TypeACME Type = 6 // RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map. RevokeAudienceKey = "revoke" @@ -104,6 +106,8 @@ func (t Type) String() string { return "AWS" case TypeAzure: return "Azure" + case TypeACME: + return "ACME" default: return "" } @@ -151,6 +155,8 @@ func (l *List) UnmarshalJSON(data []byte) error { p = &AWS{} case "azure": p = &Azure{} + case "acme": + p = &ACME{} default: // Skip unsupported provisioners. A client using this method may be // compiled with a version of smallstep/certificates that does not @@ -197,3 +203,93 @@ func SanitizeSSHUserPrincipal(email string) string { } }, strings.ToLower(email)) } + +// MockProvisioner for testing +type MockProvisioner struct { + Mret1, Mret2, Mret3 interface{} + Merr error + MgetID func() string + MgetTokenID func(string) (string, error) + MgetName func() string + MgetType func() Type + MgetEncryptedKey func() (string, string, bool) + Minit func(Config) error + MauthorizeRevoke func(ott string) error + MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error) + MauthorizeRenewal func(*x509.Certificate) error +} + +// GetID mock +func (m *MockProvisioner) GetID() string { + if m.MgetID != nil { + return m.MgetID() + } + return m.Mret1.(string) +} + +// GetTokenID mock +func (m *MockProvisioner) GetTokenID(token string) (string, error) { + if m.MgetTokenID != nil { + return m.MgetTokenID(token) + } + if m.Mret1 == nil { + return "", m.Merr + } + return m.Mret1.(string), m.Merr +} + +// GetName mock +func (m *MockProvisioner) GetName() string { + if m.MgetName != nil { + return m.MgetName() + } + return m.Mret1.(string) +} + +// GetType mock +func (m *MockProvisioner) GetType() Type { + if m.MgetType != nil { + return m.MgetType() + } + return m.Mret1.(Type) +} + +// GetEncryptedKey mock +func (m *MockProvisioner) GetEncryptedKey() (string, string, bool) { + if m.MgetEncryptedKey != nil { + return m.MgetEncryptedKey() + } + return m.Mret1.(string), m.Mret2.(string), m.Mret3.(bool) +} + +// Init mock +func (m *MockProvisioner) Init(c Config) error { + if m.Minit != nil { + return m.Minit(c) + } + return m.Merr +} + +// AuthorizeRevoke mock +func (m *MockProvisioner) AuthorizeRevoke(ott string) error { + if m.MauthorizeRevoke != nil { + return m.MauthorizeRevoke(ott) + } + return m.Merr +} + +// AuthorizeSign mock +func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) { + if m.MauthorizeSign != nil { + return m.MauthorizeSign(ctx, ott) + } + return m.Mret1.([]SignOption), m.Merr +} + +// AuthorizeRenewal mock +func (m *MockProvisioner) AuthorizeRenewal(c *x509.Certificate) error { + if m.MauthorizeRenewal != nil { + return m.MauthorizeRenewal(c) + } + return m.Merr +} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 134cabe4..2fc4a2b9 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/cli/crypto/keys" "golang.org/x/crypto/ssh" ) @@ -298,8 +299,9 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error { if err != nil { return err } - if key.Size() < 256 { - return errors.New("ssh certificate key must be at least 2048 bits (256 bytes)") + if key.Size() < keys.MinRSAKeyBytes { + return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)", + 8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes) } return nil case ssh.KeyAlgoDSA: diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index c6e820ed..91e67f02 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -730,3 +730,15 @@ func generateJWKServer(n int) *httptest.Server { srv.Start() return srv } + +func generateACME() (*ACME, error) { + // Initialize provisioners + p := &ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(Config{Claims: globalProvisionerClaims}); err != nil { + return nil, err + } + return p, nil +} diff --git a/authority/provisioners.go b/authority/provisioners.go index 5328eb4d..2d43571b 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -35,3 +35,13 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi } return p, nil } + +// LoadProvisionerByID returns an interface to the provisioner with the given ID. +func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + p, ok := a.provisioners.Load(id) + if !ok { + return nil, &apiError{errors.Errorf("provisioner not found"), + http.StatusNotFound, apiCtx{}} + } + return p, nil +} diff --git a/authority/testdata/certs/badsig.csr b/authority/testdata/certs/badsig.csr new file mode 100644 index 00000000..976e6a5d --- /dev/null +++ b/authority/testdata/certs/badsig.csr @@ -0,0 +1,8 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI +cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ +DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ +ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5 +OI+cWOIc/IGwqZul/zEF5dani5ihOL7UwA== +-----END CERTIFICATE REQUEST----- diff --git a/authority/testdata/certs/foo.csr b/authority/testdata/certs/foo.csr new file mode 100644 index 00000000..839dd0b1 --- /dev/null +++ b/authority/testdata/certs/foo.csr @@ -0,0 +1,8 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI +cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ +DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ +ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5 +OI+cWOIc/IGwqZul/zEF5dani5ihOR7UwA== +-----END CERTIFICATE REQUEST----- diff --git a/ca/acmeClient.go b/ca/acmeClient.go new file mode 100644 index 00000000..3895381f --- /dev/null +++ b/ca/acmeClient.go @@ -0,0 +1,354 @@ +package ca + +import ( + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + acmeAPI "github.com/smallstep/certificates/acme/api" + "github.com/smallstep/cli/jose" +) + +// ACMEClient implements an HTTP client to an ACME API. +type ACMEClient struct { + client *http.Client + dirLoc string + dir *acme.Directory + acc *acme.Account + Key *jose.JSONWebKey + kid string +} + +// NewACMEClient initializes a new ACMEClient. +func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) { + // Retrieve transport from options. + o := new(clientOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + tr, err := o.getTransport(endpoint) + if err != nil { + return nil, err + } + + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: endpoint, + } + + resp, err := ac.client.Get(endpoint) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", endpoint) + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + var dir acme.Directory + if err := readJSON(resp.Body, &dir); err != nil { + return nil, errors.Wrapf(err, "error reading %s", endpoint) + } + + ac.dir = &dir + + ac.Key, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + if err != nil { + return nil, err + } + + nar := &acmeAPI.NewAccountRequest{ + Contact: contact, + TermsOfServiceAgreed: true, + } + payload, err := json.Marshal(nar) + if err != nil { + return nil, errors.Wrap(err, "error marshaling new account request") + } + + resp, err = ac.post(payload, ac.dir.NewAccount, withJWK(ac)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + var acc acme.Account + if err := readJSON(resp.Body, &acc); err != nil { + return nil, errors.Wrapf(err, "error reading %s", dir.NewAccount) + } + ac.acc = &acc + ac.kid = resp.Header.Get("Location") + + return ac, nil +} + +// GetDirectory makes a directory request to the ACME api and returns an +// ACME directory object. +func (c *ACMEClient) GetDirectory() (*acme.Directory, error) { + return c.dir, nil +} + +// GetNonce makes a nonce request to the ACME api and returns an +// ACME directory object. +func (c *ACMEClient) GetNonce() (string, error) { + resp, err := c.client.Get(c.dir.NewNonce) + if err != nil { + return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce) + } + if resp.StatusCode >= 400 { + return "", readACMEError(resp.Body) + } + return resp.Header.Get("Replay-Nonce"), nil +} + +type withHeaderOption func(so *jose.SignerOptions) + +func withJWK(c *ACMEClient) withHeaderOption { + return func(so *jose.SignerOptions) { + so.WithHeader("jwk", c.Key.Public()) + } +} + +func withKid(c *ACMEClient) withHeaderOption { + return func(so *jose.SignerOptions) { + so.WithHeader("kid", c.kid) + } +} + +// serialize serializes a json web signature and doesn't omit empty fields. +func serialize(obj *jose.JSONWebSignature) (string, error) { + raw, err := obj.CompactSerialize() + if err != nil { + return "", errors.Wrap(err, "error serializing JWS") + } + parts := strings.Split(raw, ".") + msg := struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Signature string `json:"signature"` + }{Protected: parts[0], Payload: parts[1], Signature: parts[2]} + b, err := json.Marshal(msg) + if err != nil { + return "", errors.Wrap(err, "error marshaling jws message") + } + return string(b), nil +} + +func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOption) (*http.Response, error) { + if c.Key == nil { + return nil, errors.New("acme client not configured with account") + } + nonce, err := c.GetNonce() + if err != nil { + return nil, err + } + so := new(jose.SignerOptions) + so.WithHeader("nonce", nonce) + so.WithHeader("url", url) + for _, hop := range headerOps { + hop(so) + } + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(c.Key.Algorithm), + Key: c.Key.Key, + }, so) + if err != nil { + return nil, errors.Wrap(err, "error creating JWS signer") + } + signed, err := signer.Sign(payload) + if err != nil { + return nil, errors.Errorf("error signing payload: %s", strings.TrimPrefix(err.Error(), "square/go-jose: ")) + } + raw, err := serialize(signed) + if err != nil { + return nil, err + } + resp, err := c.client.Post(url, "application/jose+json", strings.NewReader(raw)) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", c.dir.NewOrder) + } + return resp, nil +} + +// NewOrder creates and returns the information for a new ACME order. +func (c *ACMEClient) NewOrder(payload []byte) (*acme.Order, error) { + resp, err := c.post(payload, c.dir.NewOrder, withKid(c)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + + var o acme.Order + if err := readJSON(resp.Body, &o); err != nil { + return nil, errors.Wrapf(err, "error reading %s", c.dir.NewOrder) + } + o.ID = resp.Header.Get("Location") + return &o, nil +} + +// GetChallenge returns the Challenge at the given path. +// With the validate parameter set to True this method will attempt to validate the +// challenge before returning it. +func (c *ACMEClient) GetChallenge(url string) (*acme.Challenge, error) { + resp, err := c.post(nil, url, withKid(c)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + + var ch acme.Challenge + if err := readJSON(resp.Body, &ch); err != nil { + return nil, errors.Wrapf(err, "error reading %s", url) + } + return &ch, nil +} + +// ValidateChallenge returns the Challenge at the given path. +// With the validate parameter set to True this method will attempt to validate the +// challenge before returning it. +func (c *ACMEClient) ValidateChallenge(url string) error { + resp, err := c.post([]byte("{}"), url, withKid(c)) + if err != nil { + return err + } + if resp.StatusCode >= 400 { + return readACMEError(resp.Body) + } + return nil +} + +// GetAuthz returns the Authz at the given path. +func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { + resp, err := c.post(nil, url, withKid(c)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + + var az acme.Authz + if err := readJSON(resp.Body, &az); err != nil { + return nil, errors.Wrapf(err, "error reading %s", url) + } + return &az, nil +} + +// GetOrder returns the Order at the given path. +func (c *ACMEClient) GetOrder(url string) (*acme.Order, error) { + resp, err := c.post(nil, url, withKid(c)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + + var o acme.Order + if err := readJSON(resp.Body, &o); err != nil { + return nil, errors.Wrapf(err, "error reading %s", url) + } + return &o, nil +} + +// FinalizeOrder makes a finalize request to the ACME api. +func (c *ACMEClient) FinalizeOrder(url string, csr *x509.CertificateRequest) error { + payload, err := json.Marshal(acmeAPI.FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + }) + if err != nil { + return errors.Wrap(err, "error marshaling finalize request") + } + resp, err := c.post(payload, url, withKid(c)) + if err != nil { + return err + } + if resp.StatusCode >= 400 { + return readACMEError(resp.Body) + } + return nil +} + +// GetCertificate retrieves the certificate along with all intermediates. +func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Certificate, error) { + resp, err := c.post(nil, url, withKid(c)) + if err != nil { + return nil, nil, err + } + if resp.StatusCode >= 400 { + return nil, nil, readACMEError(resp.Body) + } + defer resp.Body.Close() + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, nil, errors.Wrap(err, "error reading GET certificate response") + } + + var certs []*x509.Certificate + + block, rest := pem.Decode(bodyBytes) + if block == nil { + return nil, nil, errors.New("failed to parse any certificates from response") + } + for block != nil { + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(err, "error parsing certificate pem response") + } + certs = append(certs, cert) + block, rest = pem.Decode(rest) + } + + return certs[0], certs[1:], nil +} + +// GetAccountOrders retrieves the orders belonging to the given account. +func (c *ACMEClient) GetAccountOrders() ([]string, error) { + if c.acc == nil { + return nil, errors.New("acme client not configured with account") + } + resp, err := c.post(nil, c.acc.Orders, withKid(c)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, readACMEError(resp.Body) + } + + var orders []string + if err := readJSON(resp.Body, &orders); err != nil { + return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders) + } + + return orders, nil +} + +func readACMEError(r io.ReadCloser) error { + defer r.Close() + b, err := ioutil.ReadAll(r) + if err != nil { + return errors.Wrap(err, "error reading from body") + } + ae := new(acme.AError) + err = json.Unmarshal(b, &ae) + // If we successfully marshaled to an ACMEError then return the ACMEError. + if err != nil || len(ae.Error()) == 0 { + fmt.Printf("b = %s\n", b) + // Throw up our hands. + return errors.Errorf("%s", b) + } + return ae +} diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go new file mode 100644 index 00000000..d4218cd8 --- /dev/null +++ b/ca/acmeClient_test.go @@ -0,0 +1,1358 @@ +package ca + +import ( + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + acmeAPI "github.com/smallstep/certificates/acme/api" + "github.com/smallstep/certificates/api" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/jose" +) + +func TestNewACMEClient(t *testing.T) { + type test struct { + endpoint string + ops []ClientOption + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + NewAccount: srv.URL + "/bar", + NewOrder: srv.URL + "/baz", + NewAuthz: srv.URL + "/zap", + RevokeCert: srv.URL + "/zip", + KeyChange: srv.URL + "/blorp", + } + acc := acme.Account{ + Contact: []string{"max", "mariano"}, + Status: "valid", + Orders: "orders-url", + } + tests := map[string]func(t *testing.T) test{ + "fail/client-option-error": func(t *testing.T) test { + return test{ + ops: []ClientOption{ + func(o *clientOptions) error { + return errors.New("force") + }, + }, + err: errors.New("force"), + } + }, + "fail/get-directory": func(t *testing.T) test { + return test{ + ops: []ClientOption{WithTransport(http.DefaultTransport)}, + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-directory": func(t *testing.T) test { + return test{ + ops: []ClientOption{WithTransport(http.DefaultTransport)}, + r1: "foo", + rc1: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "fail/error-post-newAccount": func(t *testing.T) test { + return test{ + ops: []ClientOption{WithTransport(http.DefaultTransport)}, + r1: dir, + rc1: 200, + r2: acme.AccountDoesNotExistErr(nil).ToACME(), + rc2: 400, + err: errors.New("Account does not exist"), + } + }, + "fail/error-bad-account": func(t *testing.T) test { + return test{ + ops: []ClientOption{WithTransport(http.DefaultTransport)}, + r1: dir, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + ops: []ClientOption{WithTransport(http.DefaultTransport)}, + r1: dir, + rc1: 200, + r2: acc, + rc2: 200, + } + }, + } + + accLocation := "linkitylink" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + switch { + case i == 0: + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + case i == 1: + w.Header().Set("Replay-Nonce", "abc123") + api.JSONStatus(w, []byte{}, 200) + i++ + default: + w.Header().Set("Location", accLocation) + api.JSONStatus(w, tc.r2, tc.rc2) + } + }) + + if client, err := NewACMEClient(srv.URL, []string{"max", "mariano"}, tc.ops...); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, *client.dir, dir) + assert.NotNil(t, client.Key) + assert.NotNil(t, client.acc) + assert.Equals(t, client.kid, accLocation) + } + } + }) + } +} + +func TestACMEClient_GetDirectory(t *testing.T) { + c := &ACMEClient{ + dir: &acme.Directory{ + NewNonce: "/foo", + NewAccount: "/bar", + NewOrder: "/baz", + NewAuthz: "/zap", + RevokeCert: "/zip", + KeyChange: "/blorp", + }, + } + dir, err := c.GetDirectory() + assert.FatalError(t, err) + assert.Equals(t, c.dir, dir) +} + +func TestACMEClient_GetNonce(t *testing.T) { + type test struct { + r1 interface{} + rc1 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + } + + tests := map[string]func(t *testing.T) test{ + "fail/GET-nonce": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + } + }, + } + + expectedNonce := "abc123" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + api.JSONStatus(w, tc.r1, tc.rc1) + }) + + if nonce, err := ac.GetNonce(); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, expectedNonce, nonce) + } + } + }) + } +} + +func TestACMEClient_post(t *testing.T) { + type test struct { + payload []byte + Key *jose.JSONWebKey + ops []withHeaderOption + r1, r2 interface{} + rc1, rc2 int + jwkInJWS bool + client *ACMEClient + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := acme.Account{ + Contact: []string{"max", "mariano"}, + Status: "valid", + Orders: "orders-url", + } + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/account-not-configured": func(t *testing.T) test { + return test{ + client: &ACMEClient{}, + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("acme client not configured with account"), + } + }, + "fail/GET-nonce": func(t *testing.T) test { + return test{ + client: ac, + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "ok/jwk": func(t *testing.T) test { + return test{ + client: ac, + r1: []byte{}, + rc1: 200, + r2: acc, + rc2: 200, + ops: []withHeaderOption{withJWK(ac)}, + jwkInJWS: true, + } + }, + "ok/kid": func(t *testing.T) test { + return test{ + client: ac, + r1: []byte{}, + rc1: 200, + r2: acc, + rc2: 200, + ops: []withHeaderOption{withKid(ac)}, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/foo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + + if tc.jwkInJWS { + assert.Equals(t, hdr.JSONWebKey.KeyID, ac.Key.KeyID) + } else { + assert.Equals(t, hdr.KeyID, ac.kid) + } + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + var res acme.Account + assert.FatalError(t, readJSON(resp.Body, &res)) + assert.Equals(t, res, acc) + } + } + }) + } +} + +func TestACMEClient_NewOrder(t *testing.T) { + type test struct { + payload []byte + jwk *jose.JSONWebKey + ops []withHeaderOption + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + NewOrder: srv.URL + "/bar", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + nor := acmeAPI.NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "dns", Value: "acme.example.com"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute), + } + norb, err := json.Marshal(nor) + assert.FatalError(t, err) + ord := acme.Order{ + Status: "valid", + Expires: "soon", + Finalize: "finalize-url", + } + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/newOrder-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + ops: []withHeaderOption{withKid(ac)}, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-order": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + ops: []withHeaderOption{withKid(ac)}, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: ord, + rc2: 200, + ops: []withHeaderOption{withKid(ac)}, + } + }, + } + + expectedNonce := "abc123" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, dir.NewOrder) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + assert.Equals(t, payload, norb) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if res, err := ac.NewOrder(norb); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, *res, ord) + } + } + }) + } +} + +func TestACMEClient_GetOrder(t *testing.T) { + type test struct { + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ord := acme.Order{ + Status: "valid", + Expires: "soon", + Finalize: "finalize-url", + } + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/getOrder-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-order": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: ord, + rc2: 200, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/hullaballoo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + assert.Equals(t, len(payload), 0) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if res, err := ac.GetOrder(url); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, *res, ord) + } + } + }) + } +} + +func TestACMEClient_GetAuthz(t *testing.T) { + type test struct { + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + az := acme.Authz{ + Status: "valid", + Expires: "soon", + Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, + } + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/getChallenge-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-challenge": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: az, + rc2: 200, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/hullaballoo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + assert.Equals(t, len(payload), 0) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if res, err := ac.GetAuthz(url); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, *res, az) + } + } + }) + } +} + +func TestACMEClient_GetChallenge(t *testing.T) { + type test struct { + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ch := acme.Challenge{ + Type: "http-01", + Status: "valid", + Token: "foo", + } + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/getChallenge-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-challenge": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: ch, + rc2: 200, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/hullaballoo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + + assert.Equals(t, len(payload), 0) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if res, err := ac.GetChallenge(url); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, *res, ch) + } + } + }) + } +} + +func TestACMEClient_ValidateChallenge(t *testing.T) { + type test struct { + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ch := acme.Challenge{ + Type: "http-01", + Status: "valid", + Token: "foo", + } + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/getChallenge-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-challenge": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: ch, + rc2: 200, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/hullaballoo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + + assert.Equals(t, payload, []byte("{}")) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if err := ac.ValidateChallenge(url); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + }) + } +} + +func TestACMEClient_FinalizeOrder(t *testing.T) { + type test struct { + r1, r2 interface{} + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + ord := acme.Order{ + Status: "valid", + Expires: "soon", + Finalize: "finalize-url", + Certificate: "cert-url", + } + _csr, err := pemutil.Read("../authority/testdata/certs/foo.csr") + assert.FatalError(t, err) + csr, ok := _csr.(*x509.CertificateRequest) + assert.Fatal(t, ok) + fr := acmeAPI.FinalizeRequest{CSR: base64.RawURLEncoding.EncodeToString(csr.Raw)} + frb, err := json.Marshal(fr) + assert.FatalError(t, err) + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/finalizeOrder-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-order": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: ord, + rc2: 200, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/hullaballoo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + assert.Equals(t, payload, frb) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if err := ac.FinalizeOrder(url, csr); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + }) + } +} + +func TestACMEClient_GetAccountOrders(t *testing.T) { + type test struct { + r1, r2 interface{} + rc1, rc2 int + err error + client *ACMEClient + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + orders := []string{"foo", "bar", "baz"} + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + acc: &acme.Account{ + Contact: []string{"max", "mariano"}, + Status: "valid", + Orders: srv.URL + "/orders-url", + }, + } + + tests := map[string]func(t *testing.T) test{ + "fail/account-not-configured": func(t *testing.T) test { + return test{ + client: &ACMEClient{}, + err: errors.New("acme client not configured with account"), + } + }, + "fail/client-post": func(t *testing.T) test { + return test{ + client: ac, + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/getAccountOrders-error": func(t *testing.T) test { + return test{ + client: ac, + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-accountOrders": func(t *testing.T) test { + return test{ + client: ac, + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("error reading http://127.0.0.1"), + } + }, + "ok": func(t *testing.T) test { + return test{ + client: ac, + r1: []byte{}, + rc1: 200, + r2: orders, + rc2: 200, + } + }, + } + + expectedNonce := "abc123" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, ac.acc.Orders) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + assert.Equals(t, len(payload), 0) + + api.JSONStatus(w, tc.r2, tc.rc2) + }) + + if res, err := tc.client.GetAccountOrders(); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, res, orders) + } + } + }) + } +} + +func TestACMEClient_GetCertificate(t *testing.T) { + type test struct { + r1, r2 interface{} + certBytes []byte + rc1, rc2 int + err error + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + dir := acme.Directory{ + NewNonce: srv.URL + "/foo", + } + // Retrieve transport from options. + o := new(clientOptions) + assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) + tr, err := o.getTransport(srv.URL) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + leaf, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + leafb := pem.EncodeToMemory(&pem.Block{ + Type: "Certificate", + Bytes: leaf.Raw, + }) + certBytes := append(leafb, leafb...) + certBytes = append(certBytes, leafb...) + ac := &ACMEClient{ + client: &http.Client{ + Transport: tr, + }, + dirLoc: srv.URL, + dir: &dir, + Key: jwk, + kid: "foobar", + acc: &acme.Account{ + Contact: []string{"max", "mariano"}, + Status: "valid", + Orders: srv.URL + "/orders-url", + }, + } + + tests := map[string]func(t *testing.T) test{ + "fail/client-post": func(t *testing.T) test { + return test{ + r1: acme.MalformedErr(nil).ToACME(), + rc1: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/getAccountOrders-error": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: acme.MalformedErr(nil).ToACME(), + rc2: 400, + err: errors.New("The request message was malformed"), + } + }, + "fail/bad-certificate": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + r2: "foo", + rc2: 200, + err: errors.New("failed to parse any certificates from response"), + } + }, + "ok": func(t *testing.T) test { + return test{ + r1: []byte{}, + rc1: 200, + certBytes: certBytes, + } + }, + } + + expectedNonce := "abc123" + url := srv.URL + "/cert/foo" + + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + i := 0 + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Replay-Nonce", expectedNonce) + if i == 0 { + api.JSONStatus(w, tc.r1, tc.rc1) + i++ + return + } + + // validate jws request protected headers and body + body, err := ioutil.ReadAll(req.Body) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(body)) + assert.FatalError(t, err) + hdr := jws.Signatures[0].Protected + + assert.Equals(t, hdr.Nonce, expectedNonce) + jwsURL, ok := hdr.ExtraHeaders["url"].(string) + assert.Fatal(t, ok) + assert.Equals(t, jwsURL, url) + assert.Equals(t, hdr.KeyID, ac.kid) + + payload, err := jws.Verify(ac.Key.Public()) + assert.FatalError(t, err) + assert.Equals(t, len(payload), 0) + + if tc.certBytes != nil { + w.Write(tc.certBytes) + } else { + api.JSONStatus(w, tc.r2, tc.rc2) + } + }) + + if crt, chain, err := ac.GetCertificate(url); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, crt, leaf) + assert.Equals(t, chain, []*x509.Certificate{leaf, leaf}) + } + } + }) + } +} diff --git a/ca/ca.go b/ca/ca.go index 06f36975..5b706b09 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -3,18 +3,23 @@ package ca import ( "crypto/tls" "crypto/x509" + "fmt" "log" "net/http" + "net/url" "reflect" "github.com/go-chi/chi" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + acmeAPI "github.com/smallstep/certificates/acme/api" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/server" + "github.com/smallstep/nosql" ) type options struct { @@ -58,11 +63,12 @@ func WithDatabase(db db.AuthDB) Option { // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { - auth *authority.Authority - config *authority.Config - srv *server.Server - opts *options - renewer *TLSRenewer + auth *authority.Authority + acmeAuth *acme.Authority + config *authority.Config + srv *server.Server + opts *options + renewer *TLSRenewer } // New creates and initializes the CA with the given configuration and options. @@ -100,13 +106,47 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { mux := chi.NewRouter() handler := http.Handler(mux) - // Add api endpoints in / and /1.0 + // Add regular CA api endpoints in / and /1.0 routerHandler := api.New(auth) routerHandler.Route(mux) mux.Route("/1.0", func(r chi.Router) { routerHandler.Route(r) }) + //Add ACME api endpoints in /acme and /1.0/acme + dns := config.DNSNames[0] + u, err := url.Parse("https://" + config.Address) + if err != nil { + return nil, err + } + port := u.Port() + if port != "" && port != "443" { + dns = fmt.Sprintf("%s:%s", dns, port) + } + + prefix := "acme" + acmeAuth := acme.NewAuthority(auth.GetDatabase().(nosql.DB), dns, prefix, auth) + acmeRouterHandler := acmeAPI.New(acmeAuth) + mux.Route("/"+prefix, func(r chi.Router) { + acmeRouterHandler.Route(r) + }) + // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 + // of the ACME spec. + mux.Route("/2.0/"+prefix, func(r chi.Router) { + acmeRouterHandler.Route(r) + }) + + /* + // helpful routine for logging all routes // + walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { + fmt.Printf("%s %s\n", method, route) + return nil + } + if err := chi.Walk(mux, walkFunc); err != nil { + fmt.Printf("Logging err: %s\n", err.Error()) + } + */ + // Add monitoring if configured if len(config.Monitoring) > 0 { m, err := monitoring.New(config.Monitoring) diff --git a/ca/client_test.go b/ca/client_test.go index 5fb9aae4..1c90e52b 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -163,8 +163,7 @@ func TestClient_Health(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() @@ -224,8 +223,7 @@ func TestClient_Root(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) @@ -303,8 +301,7 @@ func TestClient_Sign(t *testing.T) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) @@ -378,8 +375,7 @@ func TestClient_Revoke(t *testing.T) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) @@ -438,8 +434,7 @@ func TestClient_Renew(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) @@ -502,8 +497,7 @@ func TestClient_Provisioners(t *testing.T) { if req.RequestURI != tt.expectedURI { t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) } - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) @@ -562,8 +556,7 @@ func TestClient_ProvisionerKey(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) @@ -622,8 +615,7 @@ func TestClient_Roots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() @@ -683,8 +675,7 @@ func TestClient_Federation(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() @@ -783,8 +774,7 @@ func TestClient_RootFingerprint(t *testing.T) { } tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) + api.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() diff --git a/db/db.go b/db/db.go index 3438046c..4d9f5ae0 100644 --- a/db/db.go +++ b/db/db.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" ) var ( @@ -102,22 +103,20 @@ func (db *DB) IsRevoked(sn string) (bool, error) { // Revoke adds a certificate to the revocation table. func (db *DB) Revoke(rci *RevokedCertificateInfo) error { - isRvkd, err := db.IsRevoked(rci.Serial) - if err != nil { - return err - } - if isRvkd { - return ErrAlreadyExists - } rcib, err := json.Marshal(rci) if err != nil { return errors.Wrap(err, "error marshaling revoked certificate info") } - if err = db.Set(revokedCertsTable, []byte(rci.Serial), rcib); err != nil { - return errors.Wrap(err, "database Set error") + _, swapped, err := db.CmpAndSwap(revokedCertsTable, []byte(rci.Serial), nil, rcib) + switch { + case err != nil: + return errors.Wrap(err, "error AuthDB CmpAndSwap") + case !swapped: + return ErrAlreadyExists + default: + return nil } - return nil } // StoreCertificate stores a certificate PEM. @@ -132,15 +131,11 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { _, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok)) - switch { - case err != nil: + if err != nil { return false, errors.Wrapf(err, "error storing used token %s/%s", string(usedOTTTable), id) - case !swapped: - return false, nil - default: - return true, nil } + return swapped, nil } // Shutdown sends a shutdown message to the database. @@ -153,3 +148,105 @@ func (db *DB) Shutdown() error { } return nil } + +// MockNoSQLDB // +type MockNoSQLDB struct { + Err error + Ret1, Ret2 interface{} + MGet func(bucket, key []byte) ([]byte, error) + MSet func(bucket, key, value []byte) error + MOpen func(dataSourceName string, opt ...database.Option) error + MClose func() error + MCreateTable func(bucket []byte) error + MDeleteTable func(bucket []byte) error + MDel func(bucket, key []byte) error + MList func(bucket []byte) ([]*database.Entry, error) + MUpdate func(tx *database.Tx) error + MCmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error) +} + +// CmpAndSwap mock +func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) { + if m.MCmpAndSwap != nil { + return m.MCmpAndSwap(bucket, key, old, newval) + } + if m.Ret1 == nil { + return nil, false, m.Err + } + return m.Ret1.([]byte), m.Ret2.(bool), m.Err +} + +// Get mock +func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) { + if m.MGet != nil { + return m.MGet(bucket, key) + } + if m.Ret1 == nil { + return nil, m.Err + } + return m.Ret1.([]byte), m.Err +} + +// Set mock +func (m *MockNoSQLDB) Set(bucket, key, value []byte) error { + if m.MSet != nil { + return m.MSet(bucket, key, value) + } + return m.Err +} + +// Open mock +func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error { + if m.MOpen != nil { + return m.MOpen(dataSourceName, opt...) + } + return m.Err +} + +// Close mock +func (m *MockNoSQLDB) Close() error { + if m.MClose != nil { + return m.MClose() + } + return m.Err +} + +// CreateTable mock +func (m *MockNoSQLDB) CreateTable(bucket []byte) error { + if m.MCreateTable != nil { + return m.MCreateTable(bucket) + } + return m.Err +} + +// DeleteTable mock +func (m *MockNoSQLDB) DeleteTable(bucket []byte) error { + if m.MDeleteTable != nil { + return m.MDeleteTable(bucket) + } + return m.Err +} + +// Del mock +func (m *MockNoSQLDB) Del(bucket, key []byte) error { + if m.MDel != nil { + return m.MDel(bucket, key) + } + return m.Err +} + +// List mock +func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) { + if m.MList != nil { + return m.MList(bucket) + } + return m.Ret1.([]*database.Entry), m.Err +} + +// Update mock +func (m *MockNoSQLDB) Update(tx *database.Tx) error { + if m.MUpdate != nil { + return m.MUpdate(tx) + } + return m.Err +} diff --git a/db/db_test.go b/db/db_test.go index a486fd84..7efc623e 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -8,97 +8,6 @@ import ( "github.com/smallstep/nosql/database" ) -type MockNoSQLDB struct { - err error - ret1, ret2 interface{} - get func(bucket, key []byte) ([]byte, error) - set func(bucket, key, value []byte) error - open func(dataSourceName string, opt ...database.Option) error - close func() error - createTable func(bucket []byte) error - deleteTable func(bucket []byte) error - del func(bucket, key []byte) error - list func(bucket []byte) ([]*database.Entry, error) - update func(tx *database.Tx) error - cmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error) -} - -func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) { - if m.cmpAndSwap != nil { - return m.cmpAndSwap(bucket, key, old, newval) - } - if m.ret1 == nil { - return nil, false, m.err - } - return m.ret1.([]byte), m.ret2.(bool), m.err -} - -func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) { - if m.get != nil { - return m.get(bucket, key) - } - if m.ret1 == nil { - return nil, m.err - } - return m.ret1.([]byte), m.err -} - -func (m *MockNoSQLDB) Set(bucket, key, value []byte) error { - if m.set != nil { - return m.set(bucket, key, value) - } - return m.err -} - -func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error { - if m.open != nil { - return m.open(dataSourceName, opt...) - } - return m.err -} - -func (m *MockNoSQLDB) Close() error { - if m.close != nil { - return m.close() - } - return m.err -} - -func (m *MockNoSQLDB) CreateTable(bucket []byte) error { - if m.createTable != nil { - return m.createTable(bucket) - } - return m.err -} - -func (m *MockNoSQLDB) DeleteTable(bucket []byte) error { - if m.deleteTable != nil { - return m.deleteTable(bucket) - } - return m.err -} - -func (m *MockNoSQLDB) Del(bucket, key []byte) error { - if m.del != nil { - return m.del(bucket, key) - } - return m.err -} - -func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) { - if m.list != nil { - return m.list(bucket) - } - return m.ret1.([]*database.Entry), m.err -} - -func (m *MockNoSQLDB) Update(tx *database.Tx) error { - if m.update != nil { - return m.update(tx) - } - return m.err -} - func TestIsRevoked(t *testing.T) { tests := map[string]struct { key string @@ -111,16 +20,16 @@ func TestIsRevoked(t *testing.T) { }, "false/ErrNotFound": { key: "sn", - db: &DB{&MockNoSQLDB{err: database.ErrNotFound, ret1: nil}, true}, + db: &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true}, }, "error/checking bucket": { key: "sn", - db: &DB{&MockNoSQLDB{err: errors.New("force"), ret1: nil}, true}, + db: &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true}, err: errors.New("error checking revocation bucket: force"), }, "true": { key: "sn", - db: &DB{&MockNoSQLDB{ret1: []byte("value")}, true}, + db: &DB{&MockNoSQLDB{Ret1: []byte("value")}, true}, isRevoked: true, }, } @@ -148,41 +57,26 @@ func TestRevoke(t *testing.T) { "error/force isRevoked": { rci: &RevokedCertificateInfo{Serial: "sn"}, db: &DB{&MockNoSQLDB{ - get: func(bucket []byte, sn []byte) ([]byte, error) { - return nil, errors.New("force IsRevoked") + MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") }, }, true}, - err: errors.New("error checking revocation bucket: force IsRevoked"), + err: errors.New("error AuthDB CmpAndSwap: force"), }, "error/was already revoked": { rci: &RevokedCertificateInfo{Serial: "sn"}, db: &DB{&MockNoSQLDB{ - get: func(bucket []byte, sn []byte) ([]byte, error) { - return nil, nil + MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil }, }, true}, err: ErrAlreadyExists, }, - "error/database set": { - rci: &RevokedCertificateInfo{Serial: "sn"}, - db: &DB{&MockNoSQLDB{ - get: func(bucket []byte, sn []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - set: func(bucket []byte, key []byte, value []byte) error { - return errors.New("force") - }, - }, true}, - err: errors.New("database Set error: force"), - }, "ok": { rci: &RevokedCertificateInfo{Serial: "sn"}, db: &DB{&MockNoSQLDB{ - get: func(bucket []byte, sn []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - set: func(bucket []byte, key []byte, value []byte) error { - return nil + MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil }, }, true}, }, @@ -214,7 +108,7 @@ func TestUseToken(t *testing.T) { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ - cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return nil, false, errors.New("force") }, }, true}, @@ -227,7 +121,7 @@ func TestUseToken(t *testing.T) { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ - cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), false, nil }, }, true}, @@ -239,7 +133,7 @@ func TestUseToken(t *testing.T) { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ - cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("bar"), true, nil }, }, true}, diff --git a/db/simple.go b/db/simple.go index 657a518f..30c2b124 100644 --- a/db/simple.go +++ b/db/simple.go @@ -6,6 +6,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/nosql/database" ) // ErrNotImplemented is an error returned when an operation is Not Implemented. @@ -61,3 +62,57 @@ func (s *SimpleDB) UseToken(id, tok string) (bool, error) { func (s *SimpleDB) Shutdown() error { return nil } + +// nosql.DB interface implementation // + +// Open opens the database available with the given options. +func (s *SimpleDB) Open(dataSourceName string, opt ...database.Option) error { + return ErrNotImplemented +} + +// Close closes the current database. +func (s *SimpleDB) Close() error { + return ErrNotImplemented +} + +// Get returns the value stored in the given table/bucket and key. +func (s *SimpleDB) Get(bucket, key []byte) ([]byte, error) { + return nil, ErrNotImplemented +} + +// Set sets the given value in the given table/bucket and key. +func (s *SimpleDB) Set(bucket, key, value []byte) error { + return ErrNotImplemented +} + +// CmpAndSwap swaps the value at the given bucket and key if the current +// value is equivalent to the oldValue input. Returns 'true' if the +// swap was successful and 'false' otherwise. +func (s *SimpleDB) CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) { + return nil, false, ErrNotImplemented +} + +// Del deletes the data in the given table/bucket and key. +func (s *SimpleDB) Del(bucket, key []byte) error { + return ErrNotImplemented +} + +// List returns a list of all the entries in a given table/bucket. +func (s *SimpleDB) List(bucket []byte) ([]*database.Entry, error) { + return nil, ErrNotImplemented +} + +// Update performs a transaction with multiple read-write commands. +func (s *SimpleDB) Update(tx *database.Tx) error { + return ErrNotImplemented +} + +// CreateTable creates a table or a bucket in the database. +func (s *SimpleDB) CreateTable(bucket []byte) error { + return ErrNotImplemented +} + +// DeleteTable deletes a table or a bucket in the database. +func (s *SimpleDB) DeleteTable(bucket []byte) error { + return ErrNotImplemented +} diff --git a/docs/acme.md b/docs/acme.md new file mode 100644 index 00000000..072aa0fa --- /dev/null +++ b/docs/acme.md @@ -0,0 +1,160 @@ +# Using ACME with `step-ca ` + +Let’s assume you’ve [installed +`step-ca`](https://smallstep.com/docs/getting-started/#1-installing-step-and-step-ca) +(e.g., using `brew install step`), have it running at `https://ca.internal`, +and you’ve [bootstrapped your ACME client +system(s)](https://smallstep.com/docs/getting-started/#bootstrapping) (or at +least [installed your root +certificate](https://smallstep.com/docs/cli/ca/root/) at +`~/.step/certs/root_ca.crt`). + +## Enabling ACME + +To enable ACME, simply [add an ACME provisioner](https://smallstep.com/docs/cli/ca/provisioner/add/) to your `step-ca` configuration +by running: + +``` +$ step ca provisioner add my-acme-provisioner --type ACME +``` + +> NOTE: The above command will add a new provisioner of type `ACME` and name +> `my-acme-provisioner`. The name is used to identify the provisioner +> (e.g. you cannot have two `ACME` provisioners with the same name). + +Now restart or SIGHUP `step-ca` to pick up the new configuration. + +That’s it. + +## Configuring Clients + +To configure an ACME client to connect to `step-ca` you need to: + +1. Point the client at the right ACME directory URL +2. Tell the client to trust your CA’s root certificate + +Once certificates are issued, you’ll also need to ensure they’re renewed before +they expire. + +### Pointing Clients at the right ACME Directory URL + +Most ACME clients connect to Let’s Encrypt by default. To connect to `step-ca` +you need to point the client at the right [ACME directory +URL](https://tools.ietf.org/html/rfc8555#section-7.1.1). + +A single instance of `step-ca` can have multiple ACME provisioners, each with +their own ACME directory URL that looks like: + +``` +https://{ca-host}/acme/{provisioner-name}/directory +``` + +We just added an ACME provisioner named “acme”. Its ACME directory URL is: + +``` +https://ca.internal/acme/acme/directory +``` + +### Telling clients to trust your CA’s root certificate + +Communication between an ACME client and server [always uses +HTTPS](https://tools.ietf.org/html/rfc8555#section-6.1). By default, client’s +will validate the server’s HTTPS certificate using the public root certificates +in your system’s [default +trust](https://smallstep.com/blog/everything-pki.html#trust-stores) store. +That’s fine when you’re connecting to Let’s Encrypt: it’s a public CA and its +root certificate is in your system’s default trust store already. Your internal +root certificate isn’t, so HTTPS connections from ACME clients to `step-ca` will +fail. + +There are two ways to address this problem: + +1. Explicitly configure your ACME client to trust `step-ca`'s root certificate, or +2. Add `step-ca`'s root certificate to your system’s default trust store (e.g., + using `[step certificate + install](https://smallstep.com/docs/cli/certificate/install/)`) + +If you’re using your CA for TLS in production, explicitly configuring your ACME +client to only trust your root certificate is a better option. We’ll +demonstrate this method with several clients below. + +If you’re simulating Let’s Encrypt in pre-production, installing your root +certificate is a more faithful simulation of production. Once your root +certificate is installed, no additional client configuration is necessary. + +> Caution: adding a root certificate to your system’s trust store is a global +> operation. Certificates issued by your CA will be trusted everywhere, +> including in web browsers. + +### Example using [`certbot`](https://certbot.eff.org/) + +[`certbot`](https://certbot.eff.org/) is the grandaddy of ACME clients. Built +and supported by [the EFF](https://www.eff.org/), it’s the standard-bearer for +production-grade command-line ACME. + +To get a certificate from `step-ca` using `certbot` you need to: + +1. Point `certbot` at your ACME directory URL using the `--`server flag. +2. Tell `certbot` to trust your root certificate using the `REQUESTS_CA_BUNDLE` environment variable. + +For example: + +``` +$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt \ + certbot certonly -n --standalone -d foo.internal \ + --server https://ca.internal/acme/acme/directory +``` + +`sudo` is required in `certbot`'s [*standalone* +mode](https://certbot.eff.org/docs/using.html#standalone) so it can listen on +port 80 to complete the `http-01` challenge. If you already have a webserver +running you can use [*webroot* +mode](https://certbot.eff.org/docs/using.html#webroot) instead. With the +[appropriate plugin](https://certbot.eff.org/docs/using.html#dns-plugins) +`certbot` also supports the `dns-01` challenge for most popular DNS providers. +Deeper integrations with [nginx](https://certbot.eff.org/docs/using.html#nginx) +and [apache](https://certbot.eff.org/docs/using.html#apache) can even configure +your server to use HTTPS automatically (we'll set this up ourselves later). All +of this works with `step-ca`. + +You can renew all of the certificates you've installed using `cerbot` by running: + +``` +$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot renew +``` + +You can automate renewal with a simple `cron` entry: + +``` +*/15 * * * * root REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot -q renew +``` + +The `certbot` packages for some Linux distributions will create a `cron` entry +or [systemd +timer](https://stevenwestmoreland.com/2017/11/renewing-certbot-certificates-using-a-systemd-timer.html) +like this for you. This entry won't work with `step-ca` because it [doesn't set +the `REQUESTS_CA_BUNDLE` environment +variable](https://github.com/certbot/certbot/issues/7170). You'll need to +manually tweak it to do so. + +More subtly, `certbot`'s default renewal job is tuned for Let's Encrypt's 90 +day certificate lifetimes: it's run every 12 hours, with actual renewals +occurring for certificates within 30 days of expiry. By default, `step-ca` +issues certificates with *much shorter* 24 hour lifetimes. The `cron` entry +above accounts for this by running `certbot renew` every 15 minutes. You'll +also want to configure your domain to only renew certificates when they're +within a few hours of expiry by adding a line like: + +``` +renew_before_expiry = 8 hours +``` + +to the top of your renewal configuration (e.g., in `/etc/letsencrypt/renewal/foo.internal.conf`). + +## Feedback + +`step-ca` should work with any ACMEv2 +([RFC8555](https://tools.ietf.org/html/rfc8555)) compliant client that supports +the http-01 or dns-01 challenge. If you run into any issues please let us know +[on gitter](https://gitter.im/smallstep/community) or [in an +issue](https://github.com/smallstep/certificates/issues/new?template=bug_report.md).