From 7731edd816bd40fcde79dd3806a45c7d85d32a0f Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 6 Jun 2023 23:37:51 -0700 Subject: [PATCH] Store and verify Acme account location (#1386) * Store and verify account location on acme requests Co-authored-by: Herman Slatman Co-authored-by: Mariano Cano --- acme/account.go | 11 ++++ acme/account_test.go | 17 ++++++ acme/api/account.go | 17 ++++-- acme/api/middleware.go | 53 +++++++++++++++--- acme/api/middleware_test.go | 101 ++++++++++++++++++++++++---------- acme/db.go | 6 ++ acme/db/nosql/account.go | 36 +++++++----- acme/db/nosql/account_test.go | 87 +++++++++++++++++------------ 8 files changed, 237 insertions(+), 91 deletions(-) diff --git a/acme/account.go b/acme/account.go index fa4b1167..38cca218 100644 --- a/acme/account.go +++ b/acme/account.go @@ -20,6 +20,16 @@ type Account struct { Status Status `json:"status"` OrdersURL string `json:"orders"` ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"` + LocationPrefix string `json:"-"` + ProvisionerName string `json:"-"` +} + +// GetLocation returns the URL location of the given account. +func (a *Account) GetLocation() string { + if a.LocationPrefix == "" { + return "" + } + return a.LocationPrefix + a.ID } // ToLog enables response logging. @@ -72,6 +82,7 @@ func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions { IPRanges: p.X509.Allowed.IPRanges, } } + func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions { if p == nil { return nil diff --git a/acme/account_test.go b/acme/account_test.go index b8ce7276..d4122500 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -66,6 +66,23 @@ func TestKeyToID(t *testing.T) { } } +func TestAccount_GetLocation(t *testing.T) { + locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/" + type test struct { + acc *Account + exp string + } + tests := map[string]test{ + "empty": {acc: &Account{LocationPrefix: ""}, exp: ""}, + "not-empty": {acc: &Account{ID: "bar", LocationPrefix: locationPrefix}, exp: locationPrefix + "bar"}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + assert.Equals(t, tc.acc.GetLocation(), tc.exp) + }) + } +} + func TestAccount_IsValid(t *testing.T) { type test struct { acc *Account diff --git a/acme/api/account.go b/acme/api/account.go index 954cb9de..ce8b5799 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "errors" "net/http" @@ -67,6 +68,12 @@ func (u *UpdateAccountRequest) Validate() error { } } +// getAccountLocationPath returns the current account URL location. +// Returned location will be of the form: https:///acme//account/ +func getAccountLocationPath(ctx context.Context, linker acme.Linker, accID string) string { + return linker.GetLink(ctx, acme.AccountLinkType, accID) +} + // NewAccount is the handler resource for creating new ACME accounts. func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -125,9 +132,11 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { } acc = &acme.Account{ - Key: jwk, - Contact: nar.Contact, - Status: acme.StatusValid, + Key: jwk, + Contact: nar.Contact, + Status: acme.StatusValid, + LocationPrefix: getAccountLocationPath(ctx, linker, ""), + ProvisionerName: prov.GetName(), } if err := db.CreateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating account")) @@ -152,7 +161,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { linker.LinkAccount(ctx, acc) - w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID)) + w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID)) render.JSONStatus(w, acc, httpStatus) } diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 5dcb93e3..ab2ab908 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "path" "strings" "go.step.sm/crypto/jose" @@ -16,7 +17,6 @@ import ( "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" - "github.com/smallstep/nosql" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -293,7 +293,6 @@ func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) - linker := acme.MustLinkerFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -301,19 +300,16 @@ func lookupJWK(next nextHTTP) nextHTTP { return } - kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID - if !strings.HasPrefix(kid, kidPrefix) { - render.Error(w, acme.NewError(acme.ErrorMalformedType, - "kid does not have required prefix; expected %s, but got %s", - kidPrefix, kid)) + if kid == "" { + render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'")) return } - accID := strings.TrimPrefix(kid, kidPrefix) + accID := path.Base(kid) acc, err := db.GetAccount(ctx, accID) switch { - case nosql.IsErrNotFound(err): + case acme.IsErrNotFound(err): render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: @@ -324,6 +320,45 @@ func lookupJWK(next nextHTTP) nextHTTP { render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } + + if storedLocation := acc.GetLocation(); storedLocation != "" { + if kid != storedLocation { + // ACME accounts should have a stored location equivalent to the + // kid in the ACME request. + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + "kid does not match stored account location; expected %s, but got %s", + storedLocation, kid)) + return + } + + // Verify that the provisioner with which the account was created + // matches the provisioner in the request URL. + reqProv := acme.MustProvisionerFromContext(ctx) + reqProvName := reqProv.GetName() + accProvName := acc.ProvisionerName + if reqProvName != accProvName { + // Provisioner in the URL must match the provisioner with + // which the account was created. + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", + accProvName, reqProvName)) + return + } + } else { + // This code will only execute for old ACME accounts that do + // not have a cached location. The following validation was + // the original implementation of the `kid` check which has + // since been deprecated. However, the code will remain to + // ensure consistent behavior for old ACME accounts. + linker := acme.MustLinkerFromContext(ctx) + kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") + if !strings.HasPrefix(kid, kidPrefix) { + render.Error(w, acme.NewError(acme.ErrorMalformedType, + "kid does not have required prefix; expected %s, but got %s", + kidPrefix, kid)) + return + } + } ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, jwkContextKey, acc.Key) next(w, r.WithContext(ctx)) diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 6e9587f5..f7db647b 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -17,7 +17,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/nosql/database" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" ) @@ -678,31 +677,7 @@ func TestHandler_lookupJWK(t *testing.T) { linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), - } - }, - "fail/bad-kid-prefix": func(t *testing.T) test { - _so := new(jose.SignerOptions) - _so.WithHeader("kid", "foo") - _signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), - Key: jwk.Key, - }, _so) - assert.FatalError(t, err) - _jws, err := _signer.Sign([]byte("baz")) - assert.FatalError(t, err) - _raw, err := _jws.CompactSerialize() - assert.FatalError(t, err) - _parsed, err := jose.ParseJWS(_raw) - assert.FatalError(t, err) - ctx := acme.NewProvisionerContext(context.Background(), prov) - ctx = context.WithValue(ctx, jwsContextKey, _parsed) - return test{ - db: &acme.MockDB{}, - linker: acme.NewLinker("test.ca.smallstep.com", "acme"), - ctx: ctx, - statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), + err: acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"), } }, "fail/account-not-found": func(t *testing.T) test { @@ -713,7 +688,7 @@ func TestHandler_lookupJWK(t *testing.T) { db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) - return nil, database.ErrNotFound + return nil, acme.ErrNotFound }, }, ctx: ctx, @@ -754,7 +729,77 @@ func TestHandler_lookupJWK(t *testing.T) { err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"), } }, - "ok": func(t *testing.T) test { + "fail/account-with-location-prefix/bad-kid": func(t *testing.T) test { + acc := &acme.Account{LocationPrefix: "foobar", Status: "valid"} + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) + return acc, nil + }, + }, + ctx: ctx, + statusCode: http.StatusUnauthorized, + err: acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected foobar, but %q", prefix+accID), + } + }, + "fail/account-with-location-prefix/bad-provisioner": func(t *testing.T) test { + acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: "other"} + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) + return acc, nil + }, + }, + ctx: ctx, + next: func(w http.ResponseWriter, r *http.Request) { + _acc, err := accountFromContext(r.Context()) + assert.FatalError(t, err) + assert.Equals(t, _acc, acc) + _jwk, err := jwkFromContext(r.Context()) + assert.FatalError(t, err) + assert.Equals(t, _jwk, jwk) + w.Write(testBody) + }, + statusCode: http.StatusUnauthorized, + err: acme.NewError(acme.ErrorUnauthorizedType, + "account provisioner does not match requested provisioner; account provisioner = %s, reqested provisioner = %s", + prov.GetName(), "other"), + } + }, + "ok/account-with-location-prefix": func(t *testing.T) test { + acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: prov.GetName()} + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) + return acc, nil + }, + }, + ctx: ctx, + next: func(w http.ResponseWriter, r *http.Request) { + _acc, err := accountFromContext(r.Context()) + assert.FatalError(t, err) + assert.Equals(t, _acc, acc) + _jwk, err := jwkFromContext(r.Context()) + assert.FatalError(t, err) + assert.Equals(t, _jwk, jwk) + w.Write(testBody) + }, + statusCode: http.StatusOK, + } + }, + "ok/account-without-location-prefix": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) diff --git a/acme/db.go b/acme/db.go index d7c9d5f4..fa9aa0de 100644 --- a/acme/db.go +++ b/acme/db.go @@ -12,6 +12,12 @@ import ( // account. var ErrNotFound = errors.New("not found") +// IsErrNotFound returns true if the error is a "not found" error. Returns false +// otherwise. +func IsErrNotFound(err error) bool { + return errors.Is(err, ErrNotFound) +} + // DB is the DB interface expected by the step-ca ACME API. type DB interface { CreateAccount(ctx context.Context, acc *Account) error diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 8067a4b9..d590ccb3 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -13,12 +13,14 @@ import ( // dbAccount represents an ACME account. type dbAccount struct { - ID string `json:"id"` - Key *jose.JSONWebKey `json:"key"` - Contact []string `json:"contact,omitempty"` - Status acme.Status `json:"status"` - CreatedAt time.Time `json:"createdAt"` - DeactivatedAt time.Time `json:"deactivatedAt"` + ID string `json:"id"` + Key *jose.JSONWebKey `json:"key"` + Contact []string `json:"contact,omitempty"` + Status acme.Status `json:"status"` + LocationPrefix string `json:"locationPrefix"` + ProvisionerName string `json:"provisionerName"` + CreatedAt time.Time `json:"createdAt"` + DeactivatedAt time.Time `json:"deactivatedAt"` } func (dba *dbAccount) clone() *dbAccount { @@ -62,10 +64,12 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) } return &acme.Account{ - Status: dbacc.Status, - Contact: dbacc.Contact, - Key: dbacc.Key, - ID: dbacc.ID, + Status: dbacc.Status, + Contact: dbacc.Contact, + Key: dbacc.Key, + ID: dbacc.ID, + LocationPrefix: dbacc.LocationPrefix, + ProvisionerName: dbacc.ProvisionerName, }, nil } @@ -87,11 +91,13 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { } dba := &dbAccount{ - ID: acc.ID, - Key: acc.Key, - Contact: acc.Contact, - Status: acc.Status, - CreatedAt: clock.Now(), + ID: acc.ID, + Key: acc.Key, + Contact: acc.Contact, + Status: acc.Status, + CreatedAt: clock.Now(), + LocationPrefix: acc.LocationPrefix, + ProvisionerName: acc.ProvisionerName, } kid, err := acme.KeyToID(dba.Key) diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go index 6097cc5a..085ce2eb 100644 --- a/acme/db/nosql/account_test.go +++ b/acme/db/nosql/account_test.go @@ -197,6 +197,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) { func TestDB_GetAccount(t *testing.T) { accID := "accID" + locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/" + provisionerName := "foo" type test struct { db nosql.DB err error @@ -222,12 +224,14 @@ func TestDB_GetAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) dbacc := &dbAccount{ - ID: accID, - Status: acme.StatusDeactivated, - CreatedAt: now, - DeactivatedAt: now, - Contact: []string{"foo", "bar"}, - Key: jwk, + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + LocationPrefix: locationPrefix, + ProvisionerName: provisionerName, } b, err := json.Marshal(dbacc) assert.FatalError(t, err) @@ -266,6 +270,8 @@ func TestDB_GetAccount(t *testing.T) { assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.LocationPrefix, tc.dbacc.LocationPrefix) + assert.Equals(t, acc.ProvisionerName, tc.dbacc.ProvisionerName) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) } }) @@ -379,6 +385,7 @@ func TestDB_GetAccountByKeyID(t *testing.T) { } func TestDB_CreateAccount(t *testing.T) { + locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/" type test struct { db nosql.DB acc *acme.Account @@ -390,9 +397,10 @@ func TestDB_CreateAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ - Status: acme.StatusValid, - Contact: []string{"foo", "bar"}, - Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ @@ -413,9 +421,10 @@ func TestDB_CreateAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ - Status: acme.StatusValid, - Contact: []string{"foo", "bar"}, - Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ @@ -436,9 +445,10 @@ func TestDB_CreateAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ - Status: acme.StatusValid, - Contact: []string{"foo", "bar"}, - Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ @@ -456,6 +466,8 @@ func TestDB_CreateAccount(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbacc)) assert.Equals(t, dbacc.ID, string(key)) assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix) + assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName) assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) @@ -479,9 +491,10 @@ func TestDB_CreateAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ - Status: acme.StatusValid, - Contact: []string{"foo", "bar"}, - Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ @@ -500,6 +513,8 @@ func TestDB_CreateAccount(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbacc)) assert.Equals(t, dbacc.ID, string(key)) assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix) + assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName) assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) @@ -539,12 +554,14 @@ func TestDB_UpdateAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) dbacc := &dbAccount{ - ID: accID, - Status: acme.StatusDeactivated, - CreatedAt: now, - DeactivatedAt: now, - Contact: []string{"foo", "bar"}, - Key: jwk, + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + LocationPrefix: "foo", + ProvisionerName: "alpha", + Key: jwk, } b, err := json.Marshal(dbacc) assert.FatalError(t, err) @@ -644,10 +661,12 @@ func TestDB_UpdateAccount(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ - ID: accID, - Status: acme.StatusDeactivated, - Contact: []string{"foo", "bar"}, - Key: jwk, + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"baz", "zap"}, + LocationPrefix: "bar", + ProvisionerName: "beta", + Key: jwk, } return test{ acc: acc, @@ -666,7 +685,10 @@ func TestDB_UpdateAccount(t *testing.T) { assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbacc.ID) assert.Equals(t, dbNew.Status, acc.Status) - assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Contact, acc.Contact) + // LocationPrefix should not change. + assert.Equals(t, dbNew.LocationPrefix, dbacc.LocationPrefix) + assert.Equals(t, dbNew.ProvisionerName, dbacc.ProvisionerName) assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) @@ -686,12 +708,7 @@ func TestDB_UpdateAccount(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, dbacc.ID) - assert.Equals(t, tc.acc.Status, dbacc.Status) - assert.Equals(t, tc.acc.Contact, dbacc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID) - } + assert.Nil(t, tc.err) } }) }