commit
b724af30ad
@ -1,197 +1,42 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// Account is a subset of the internal account type containing only those
|
||||
// attributes required for responses in the ACME protocol.
|
||||
type Account struct {
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Orders string `json:"orders"`
|
||||
ID string `json:"-"`
|
||||
Key *jose.JSONWebKey `json:"-"`
|
||||
ID string `json:"-"`
|
||||
Key *jose.JSONWebKey `json:"-"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status Status `json:"status"`
|
||||
OrdersURL string `json:"orders"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Account) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
|
||||
return nil, WrapErrorISE(err, "error marshaling account for logging")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the account ID.
|
||||
func (a *Account) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// GetKey returns the JWK associated with the account.
|
||||
func (a *Account) GetKey() *jose.JSONWebKey {
|
||||
return a.Key
|
||||
}
|
||||
|
||||
// IsValid returns true if the Account is valid.
|
||||
func (a *Account) IsValid() bool {
|
||||
return a.Status == StatusValid
|
||||
}
|
||||
|
||||
// AccountOptions are the options needed to create a new ACME account.
|
||||
type AccountOptions struct {
|
||||
Key *jose.JSONWebKey
|
||||
Contact []string
|
||||
}
|
||||
|
||||
// account represents an ACME account.
|
||||
type account struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
Deactivated time.Time `json:"deactivated"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// newAccount returns a new acme account type.
|
||||
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &account{
|
||||
ID: id,
|
||||
Key: ops.Key,
|
||||
Contact: ops.Contact,
|
||||
Status: "valid",
|
||||
Created: clock.Now(),
|
||||
}
|
||||
return a, a.saveNew(db)
|
||||
}
|
||||
|
||||
// toACME converts the internal Account type into the public acmeAccount
|
||||
// type for presentation in the ACME protocol.
|
||||
func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) {
|
||||
return &Account{
|
||||
Status: a.Status,
|
||||
Contact: a.Contact,
|
||||
Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID),
|
||||
Key: a.Key,
|
||||
ID: a.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// save writes the Account to the DB.
|
||||
// If the account is new then the necessary indices will be created.
|
||||
// Else, the account in the DB will be updated.
|
||||
func (a *account) saveNew(db nosql.DB) error {
|
||||
kid, err := keyToID(a.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kidB := []byte(kid)
|
||||
|
||||
// Set the jwkID -> acme account ID index
|
||||
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
|
||||
default:
|
||||
if err = a.save(db, nil); err != nil {
|
||||
db.Del(accountByKeyIDTable, kidB)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *account) save(db nosql.DB, old *account) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(*a)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling new account object")
|
||||
}
|
||||
// Set the Account
|
||||
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing account"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing account; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// update updates the acme account object stored in the database if,
|
||||
// and only if, the account has not changed since the last read.
|
||||
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
|
||||
b := *a
|
||||
b.Contact = contact
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// deactivate deactivates the acme account.
|
||||
func (a *account) deactivate(db nosql.DB) (*account, error) {
|
||||
b := *a
|
||||
b.Status = StatusDeactivated
|
||||
b.Deactivated = clock.Now()
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// getAccountByID retrieves the account with the given ID.
|
||||
func getAccountByID(db nosql.DB, id string) (*account, error) {
|
||||
ab, err := db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
|
||||
}
|
||||
|
||||
a := new(account)
|
||||
if err = json.Unmarshal(ab, a); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
|
||||
}
|
||||
return a, nil
|
||||
return Status(a.Status) == StatusValid
|
||||
}
|
||||
|
||||
// getAccountByKeyID retrieves Id associated with the given Kid.
|
||||
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
|
||||
id, err := db.Get(accountByKeyIDTable, []byte(kid))
|
||||
// KeyToID converts a JWK to a thumbprint.
|
||||
func KeyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
|
||||
return "", WrapErrorISE(err, "error generating jwk thumbprint")
|
||||
}
|
||||
return getAccountByID(db, string(id))
|
||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
@ -1,770 +1,81 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func newProv() Provisioner {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func newAcc() (*account, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newAccount(mockdb, AccountOptions{
|
||||
Key: jwk, Contact: []string{"foo", "bar"},
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAccountByID(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountByKeyID(t *testing.T) {
|
||||
type test struct {
|
||||
kid string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/kid-not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading key-account index: force")),
|
||||
}
|
||||
},
|
||||
"fail/getAccount-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
count++
|
||||
return []byte("bar"), nil
|
||||
}
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading account bar: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
kid: acc.Key.KeyID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
var ret []byte
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(acc.Key.KeyID))
|
||||
ret = []byte(acc.ID)
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
ret = b
|
||||
}
|
||||
count++
|
||||
return ret, nil
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
|
||||
|
||||
type test struct {
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{acc: acc}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAccount, err := tc.acc.toACME(ctx, nil, dir)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
|
||||
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
|
||||
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acmeAccount.Orders,
|
||||
fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSave(t *testing.T) {
|
||||
type test struct {
|
||||
acc, old *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAcc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAcc)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: oldAcc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSaveNew(t *testing.T) {
|
||||
func TestKeyToID(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
jwk *jose.JSONWebKey
|
||||
exp string
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/keyToID-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
"fail/error-generating-thumbprint": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc.Key.Key = "foo"
|
||||
jwk.Key = "foo"
|
||||
return test{
|
||||
acc: acc,
|
||||
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
|
||||
}
|
||||
},
|
||||
"fail/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"fail/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
MDel: func(bucket, key []byte) error {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
return nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
jwk: jwk,
|
||||
err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.saveNew(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUpdate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
contact []string
|
||||
db nosql.DB
|
||||
res []byte
|
||||
err *Error
|
||||
}
|
||||
contact := []string{"foo", "bar"}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
res: b,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.update(tc.db, tc.contact)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, b, tc.res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeactivate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
jwk: jwk,
|
||||
exp: base64.RawURLEncoding.EncodeToString(kid),
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.deactivate(tc.db)
|
||||
if err != nil {
|
||||
if id, err := KeyToID(tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.acc.ID)
|
||||
assert.Equals(t, acc.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acc.Status, StatusDeactivated)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acc.Created, tc.acc.Created)
|
||||
|
||||
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
|
||||
assert.Equals(t, id, tc.exp)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAccount(t *testing.T) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(jwk)
|
||||
assert.FatalError(t, err)
|
||||
ops := AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
func TestAccount_IsValid(t *testing.T) {
|
||||
type test struct {
|
||||
ops AccountOptions
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
acc *Account
|
||||
exp bool
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/store-error": func(t *testing.T) test {
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
count := 0
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
*id = string(key)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
tests := map[string]test{
|
||||
"valid": {acc: &Account{Status: StatusValid}, exp: true},
|
||||
"invalid": {acc: &Account{Status: StatusInvalid}, exp: false},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acc, err := newAccount(tc.db, tc.ops)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, *tc.id)
|
||||
assert.Equals(t, acc.Status, StatusValid)
|
||||
assert.Equals(t, acc.Contact, ops.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
|
||||
|
||||
assert.True(t, acc.Deactivated.IsZero())
|
||||
|
||||
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
}
|
||||
}
|
||||
assert.Equals(t, tc.acc.IsValid(), tc.exp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,182 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string
|
||||
GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string
|
||||
|
||||
LinkOrder(ctx context.Context, o *acme.Order)
|
||||
LinkAccount(ctx context.Context, o *acme.Account)
|
||||
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string {
|
||||
var provName string
|
||||
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
||||
provName = p.GetName()
|
||||
}
|
||||
return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...)
|
||||
}
|
||||
|
||||
// GetLinkExplicit returns an absolute or partial path to the given resource and a base
|
||||
// URL dynamically obtained from the request for which the link is being
|
||||
// calculated.
|
||||
func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
|
||||
var u = url.URL{}
|
||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||
if baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
u.Path = fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
u.Path = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
u.Path = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
u.Path = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
u.Path = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
}
|
||||
|
||||
if abs {
|
||||
// If no Scheme is set, then default to https.
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + u.Path
|
||||
return u.String()
|
||||
}
|
||||
return u.EscapedPath()
|
||||
}
|
||||
|
||||
// LinkType captures the link type.
|
||||
type LinkType int
|
||||
|
||||
const (
|
||||
// NewNonceLinkType new-nonce
|
||||
NewNonceLinkType LinkType = iota
|
||||
// NewAccountLinkType new-account
|
||||
NewAccountLinkType
|
||||
// AccountLinkType account
|
||||
AccountLinkType
|
||||
// OrderLinkType order
|
||||
OrderLinkType
|
||||
// NewOrderLinkType new-order
|
||||
NewOrderLinkType
|
||||
// OrdersByAccountLinkType list of orders owned by account
|
||||
OrdersByAccountLinkType
|
||||
// FinalizeLinkType finalize order
|
||||
FinalizeLinkType
|
||||
// NewAuthzLinkType authz
|
||||
NewAuthzLinkType
|
||||
// AuthzLinkType new-authz
|
||||
AuthzLinkType
|
||||
// ChallengeLinkType challenge
|
||||
ChallengeLinkType
|
||||
// CertificateLinkType certificate
|
||||
CertificateLinkType
|
||||
// DirectoryLinkType directory
|
||||
DirectoryLinkType
|
||||
// RevokeCertLinkType revoke certificate
|
||||
RevokeCertLinkType
|
||||
// KeyChangeLinkType key rollover
|
||||
KeyChangeLinkType
|
||||
)
|
||||
|
||||
func (l LinkType) String() string {
|
||||
switch l {
|
||||
case NewNonceLinkType:
|
||||
return "new-nonce"
|
||||
case NewAccountLinkType:
|
||||
return "new-account"
|
||||
case AccountLinkType:
|
||||
return "account"
|
||||
case NewOrderLinkType:
|
||||
return "new-order"
|
||||
case OrderLinkType:
|
||||
return "order"
|
||||
case NewAuthzLinkType:
|
||||
return "new-authz"
|
||||
case AuthzLinkType:
|
||||
return "authz"
|
||||
case ChallengeLinkType:
|
||||
return "challenge"
|
||||
case CertificateLinkType:
|
||||
return "certificate"
|
||||
case DirectoryLinkType:
|
||||
return "directory"
|
||||
case RevokeCertLinkType:
|
||||
return "revoke-cert"
|
||||
case KeyChangeLinkType:
|
||||
return "key-change"
|
||||
default:
|
||||
return fmt.Sprintf("unexpected LinkType '%d'", int(l))
|
||||
}
|
||||
}
|
||||
|
||||
// LinkOrder sets the ACME links required by an ACME order.
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||
for i, azID := range o.AuthorizationIDs {
|
||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID)
|
||||
}
|
||||
o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID)
|
||||
if o.CertificateID != "" {
|
||||
o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID)
|
||||
}
|
||||
}
|
||||
|
||||
// LinkAccount sets the ACME links required by an ACME account.
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID)
|
||||
}
|
||||
|
||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID)
|
||||
}
|
||||
|
||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
||||
for _, ch := range az.Challenges {
|
||||
l.LinkChallenge(ctx, ch, az.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// LinkOrdersByAccountID converts each order ID to an ACME link.
|
||||
func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
|
||||
for i, id := range orders {
|
||||
orders[i] = l.GetLink(ctx, OrderLinkType, true, id)
|
||||
}
|
||||
}
|
@ -0,0 +1,302 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
func TestLinker_GetLink(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, true),
|
||||
fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName))
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName))
|
||||
|
||||
// No provisioner
|
||||
ctxNoProv := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, true),
|
||||
fmt.Sprintf("%s/acme//new-nonce", baseURL.String()))
|
||||
assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, false), "//new-nonce")
|
||||
|
||||
// No baseURL
|
||||
ctxNoBaseURL := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, true),
|
||||
fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName))
|
||||
assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, true, id),
|
||||
fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName))
|
||||
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName))
|
||||
}
|
||||
|
||||
func TestLinker_GetLinkExplicit(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provName := prov.GetName()
|
||||
escProvName := url.PathEscape(provName)
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-nonce", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-account", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-order", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-authz", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, false, baseURL), fmt.Sprintf("/%s/directory", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, false, baseURL), fmt.Sprintf("/%s/revoke-cert", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, false, baseURL), fmt.Sprintf("/%s/key-change", escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
|
||||
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", escProvName, id, id))
|
||||
|
||||
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", escProvName))
|
||||
}
|
||||
|
||||
func TestLinker_LinkOrder(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
oid := "orderID"
|
||||
certID := "certID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
o *acme.Order
|
||||
validate func(o *acme.Order)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"no-authz-and-no-cert": {
|
||||
o: &acme.Order{
|
||||
ID: oid,
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||
assert.Equals(t, o.CertificateURL, "")
|
||||
},
|
||||
},
|
||||
"one-authz-and-cert": {
|
||||
o: &acme.Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
})
|
||||
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
|
||||
},
|
||||
},
|
||||
"many-authz": {
|
||||
o: &acme.Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"),
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"),
|
||||
})
|
||||
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkOrder(ctx, tc.o)
|
||||
tc.validate(tc.o)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkAccount(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
accID := "accountID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
a *acme.Account
|
||||
validate func(o *acme.Account)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
a: &acme.Account{
|
||||
ID: accID,
|
||||
},
|
||||
validate: func(a *acme.Account) {
|
||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkAccount(ctx, tc.a)
|
||||
tc.validate(tc.a)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkChallenge(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
ch *acme.Challenge
|
||||
validate func(o *acme.Challenge)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
ch: &acme.Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
validate: func(ch *acme.Challenge) {
|
||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkChallenge(ctx, tc.ch, azID)
|
||||
tc.validate(tc.ch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
chID0 := "chID-0"
|
||||
chID1 := "chID-1"
|
||||
chID2 := "chID-2"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
az *acme.Authorization
|
||||
validate func(o *acme.Authorization)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
az: &acme.Authorization{
|
||||
ID: azID,
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: chID0},
|
||||
{ID: chID1},
|
||||
{ID: chID2},
|
||||
},
|
||||
},
|
||||
validate: func(az *acme.Authorization) {
|
||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkAuthorization(ctx, tc.az)
|
||||
tc.validate(tc.az)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
oids []string
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
oids: []string{"foo", "bar", "baz"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkOrdersByAccountID(ctx, tc.oids)
|
||||
assert.Equals(t, tc.oids, []string{
|
||||
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"),
|
||||
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,342 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
database "github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// Interface is the acme authority interface.
|
||||
type Interface interface {
|
||||
GetDirectory(ctx context.Context) (*Directory, error)
|
||||
NewNonce() (string, error)
|
||||
UseNonce(string) error
|
||||
|
||||
DeactivateAccount(ctx context.Context, accID string) (*Account, error)
|
||||
GetAccount(ctx context.Context, accID string) (*Account, error)
|
||||
GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error)
|
||||
NewAccount(ctx context.Context, ao AccountOptions) (*Account, error)
|
||||
UpdateAccount(context.Context, string, []string) (*Account, error)
|
||||
|
||||
GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error)
|
||||
ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error)
|
||||
|
||||
FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error)
|
||||
GetOrder(ctx context.Context, accID string, orderID string) (*Order, error)
|
||||
GetOrdersByAccount(ctx context.Context, accID string) ([]string, error)
|
||||
NewOrder(ctx context.Context, oo OrderOptions) (*Order, error)
|
||||
|
||||
GetCertificate(string, string) ([]byte, error)
|
||||
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string
|
||||
GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string
|
||||
}
|
||||
|
||||
// Authority is the layer that handles all ACME interactions.
|
||||
type Authority struct {
|
||||
backdate provisioner.Duration
|
||||
db nosql.DB
|
||||
dir *directory
|
||||
signAuth SignAuthority
|
||||
}
|
||||
|
||||
// AuthorityOptions required to create a new ACME Authority.
|
||||
type AuthorityOptions struct {
|
||||
Backdate provisioner.Duration
|
||||
// DB is the database used by nosql.
|
||||
DB nosql.DB
|
||||
// DNS the host used to generate accurate ACME links. By default the authority
|
||||
// will use the Host from the request, so this value will only be used if
|
||||
// request.Host is empty.
|
||||
DNS string
|
||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||
// prefix is required to generate accurate ACME links.
|
||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||
// "acme" is the prefix from which the ACME api is accessed.
|
||||
Prefix string
|
||||
}
|
||||
|
||||
var (
|
||||
accountTable = []byte("acme_accounts")
|
||||
accountByKeyIDTable = []byte("acme_keyID_accountID_index")
|
||||
authzTable = []byte("acme_authzs")
|
||||
challengeTable = []byte("acme_challenges")
|
||||
nonceTable = []byte("nonces")
|
||||
orderTable = []byte("acme_orders")
|
||||
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
||||
certTable = []byte("acme_certs")
|
||||
)
|
||||
|
||||
// NewAuthority returns a new Authority that implements the ACME interface.
|
||||
//
|
||||
// Deprecated: NewAuthority exists for hitorical compatibility and should not
|
||||
// be used. Use acme.New() instead.
|
||||
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) {
|
||||
return New(signAuth, AuthorityOptions{
|
||||
DB: db,
|
||||
DNS: dns,
|
||||
Prefix: prefix,
|
||||
})
|
||||
}
|
||||
|
||||
// New returns a new Autohrity that implements the ACME interface.
|
||||
func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) {
|
||||
if _, ok := ops.DB.(*database.SimpleDB); !ok {
|
||||
// If it's not a SimpleDB then go ahead and bootstrap the DB with the
|
||||
// necessary ACME tables. SimpleDB should ONLY be used for testing.
|
||||
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
||||
challengeTable, nonceTable, orderTable, ordersByAccountIDTable,
|
||||
certTable}
|
||||
for _, b := range tables {
|
||||
if err := ops.DB.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Authority{
|
||||
backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetLink returns the requested link from the directory.
|
||||
func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
|
||||
return a.dir.getLink(ctx, typ, abs, inputs...)
|
||||
}
|
||||
|
||||
// GetLinkExplicit returns the requested link from the directory.
|
||||
func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string {
|
||||
return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...)
|
||||
}
|
||||
|
||||
// GetDirectory returns the ACME directory object.
|
||||
func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) {
|
||||
return &Directory{
|
||||
NewNonce: a.dir.getLink(ctx, NewNonceLink, true),
|
||||
NewAccount: a.dir.getLink(ctx, NewAccountLink, true),
|
||||
NewOrder: a.dir.getLink(ctx, NewOrderLink, true),
|
||||
RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true),
|
||||
KeyChange: a.dir.getLink(ctx, KeyChangeLink, true),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadProvisionerByID calls out to the SignAuthority interface to load a
|
||||
// provisioner by ID.
|
||||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
return a.signAuth.LoadProvisionerByID(id)
|
||||
}
|
||||
|
||||
// NewNonce generates, stores, and returns a new ACME nonce.
|
||||
func (a *Authority) NewNonce() (string, error) {
|
||||
n, err := newNonce(a.db)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return n.ID, nil
|
||||
}
|
||||
|
||||
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
|
||||
func (a *Authority) UseNonce(nonce string) error {
|
||||
return useNonce(a.db, nonce)
|
||||
}
|
||||
|
||||
// NewAccount creates, stores, and returns a new ACME account.
|
||||
func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) {
|
||||
acc, err := newAccount(a.db, ao)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// UpdateAccount updates an ACME account.
|
||||
func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if acc, err = acc.update(a.db, contact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// GetAccount returns an ACME account.
|
||||
func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates an ACME account.
|
||||
func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if acc, err = acc.deactivate(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
func keyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
||||
// GetAccountByKey returns the ACME associated with the jwk id.
|
||||
func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) {
|
||||
kid, err := keyToID(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acc, err := getAccountByKeyID(a.db, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// GetOrder returns an ACME order.
|
||||
func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
if o, err = o.updateStatus(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// GetOrdersByAccount returns the list of order urls owned by the account.
|
||||
func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
|
||||
ordersByAccountMux.Lock()
|
||||
defer ordersByAccountMux.Unlock()
|
||||
|
||||
var oiba = orderIDsByAccount{}
|
||||
oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret = []string{}
|
||||
for _, oid := range oids {
|
||||
ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid))
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// NewOrder generates, stores, and returns a new ACME order.
|
||||
func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) {
|
||||
prov, err := ProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ops.backdate = a.backdate.Duration
|
||||
ops.defaultDuration = prov.DefaultTLSCertDuration()
|
||||
order, err := newOrder(a.db, ops)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating order")
|
||||
}
|
||||
return order.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and generate a new certificate.
|
||||
func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
|
||||
prov, err := ProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
o, err = o.finalize(a.db, csr, a.signAuth, prov)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error finalizing order")
|
||||
}
|
||||
return o.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// GetAuthz retrieves and attempts to update the status on an ACME authz
|
||||
// before returning.
|
||||
func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) {
|
||||
az, err := getAuthz(a.db, authzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != az.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own authz"))
|
||||
}
|
||||
az, err = az.updateStatus(a.db)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error updating authz status")
|
||||
}
|
||||
return az.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// ValidateChallenge attempts to validate the challenge.
|
||||
func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
|
||||
ch, err := getChallenge(a.db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != ch.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: time.Duration(30 * time.Second),
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
ch, err = ch.validate(a.db, jwk, validateOptions{
|
||||
httpGet: client.Get,
|
||||
lookupTxt: net.LookupTXT,
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(dialer, network, addr, config)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error attempting challenge validation")
|
||||
}
|
||||
return ch.toACME(ctx, a.db, a.dir)
|
||||
}
|
||||
|
||||
// GetCertificate retrieves the Certificate by ID.
|
||||
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
|
||||
cert, err := getCert(a.db, certID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != cert.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
|
||||
}
|
||||
return cert.toACME(a.db, a.dir)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,69 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Authorization representst an ACME Authorization.
|
||||
type Authorization struct {
|
||||
ID string `json:"-"`
|
||||
AccountID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status Status `json:"status"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ExpiresAt time.Time `json:"expires"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (az *Authorization) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(az)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling authz for logging")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the ACME Authorization Status if necessary.
|
||||
// Changes to the Authorization are saved using the database interface.
|
||||
func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
|
||||
now := clock.Now()
|
||||
|
||||
switch az.Status {
|
||||
case StatusInvalid:
|
||||
return nil
|
||||
case StatusValid:
|
||||
return nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(az.ExpiresAt) {
|
||||
az.Status = StatusInvalid
|
||||
break
|
||||
}
|
||||
|
||||
var isValid = false
|
||||
for _, ch := range az.Challenges {
|
||||
if ch.Status == StatusValid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return nil
|
||||
}
|
||||
az.Status = StatusValid
|
||||
az.Error = nil
|
||||
default:
|
||||
return NewErrorISE("unrecognized authorization status: %s", az.Status)
|
||||
}
|
||||
|
||||
if err := db.UpdateAuthorization(ctx, az); err != nil {
|
||||
return WrapErrorISE(err, "error updating authorization")
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,150 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestAuthorization_UpdateStatus(t *testing.T) {
|
||||
type test struct {
|
||||
az *Authorization
|
||||
err *Error
|
||||
db DB
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/already-invalid": func(t *testing.T) test {
|
||||
az := &Authorization{
|
||||
Status: StatusInvalid,
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
}
|
||||
},
|
||||
"ok/already-valid": func(t *testing.T) test {
|
||||
az := &Authorization{
|
||||
Status: StatusInvalid,
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
}
|
||||
},
|
||||
"fail/error-unexpected-status": func(t *testing.T) test {
|
||||
az := &Authorization{
|
||||
Status: "foo",
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
err: NewErrorISE("unrecognized authorization status: %s", az.Status),
|
||||
}
|
||||
},
|
||||
"ok/expired": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(-5 * time.Minute),
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
db: &MockDB{
|
||||
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||
assert.Equals(t, updaz.ID, az.ID)
|
||||
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, updaz.Status, StatusInvalid)
|
||||
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/db.UpdateAuthorization-error": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(-5 * time.Minute),
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
db: &MockDB{
|
||||
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||
assert.Equals(t, updaz.ID, az.ID)
|
||||
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, updaz.Status, StatusInvalid)
|
||||
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
err: NewErrorISE("error updating authorization: force"),
|
||||
}
|
||||
},
|
||||
"ok/no-valid-challenges": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*Challenge{
|
||||
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending},
|
||||
},
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
}
|
||||
},
|
||||
"ok/valid": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*Challenge{
|
||||
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid},
|
||||
},
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
db: &MockDB{
|
||||
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||
assert.Equals(t, updaz.ID, az.ID)
|
||||
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, updaz.Status, StatusValid)
|
||||
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||
assert.Equals(t, updaz.Error, nil)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
@ -1,347 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
var defaultExpiryDuration = time.Hour * 24
|
||||
|
||||
// Authz is a subset of the Authz type containing only those attributes
|
||||
// required for responses in the ACME protocol.
|
||||
type Authz struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Authz) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Authz ID.
|
||||
func (a *Authz) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// authz is the interface that the various authz types must implement.
|
||||
type authz interface {
|
||||
save(nosql.DB, authz) error
|
||||
clone() *baseAuthz
|
||||
getID() string
|
||||
getAccountID() string
|
||||
getType() string
|
||||
getIdentifier() Identifier
|
||||
getStatus() string
|
||||
getExpiry() time.Time
|
||||
getWildcard() bool
|
||||
getChallenges() []string
|
||||
getCreated() time.Time
|
||||
updateStatus(db nosql.DB) (authz, error)
|
||||
toACME(context.Context, nosql.DB, *directory) (*Authz, error)
|
||||
}
|
||||
|
||||
// baseAuthz is the base authz type that others build from.
|
||||
type baseAuthz struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires time.Time `json:"expires"`
|
||||
Challenges []string `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
Created time.Time `json:"created"`
|
||||
Error *Error `json:"error"`
|
||||
}
|
||||
|
||||
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
ba := &baseAuthz{
|
||||
ID: id,
|
||||
AccountID: accID,
|
||||
Status: StatusPending,
|
||||
Created: now,
|
||||
Expires: now.Add(defaultExpiryDuration),
|
||||
Identifier: identifier,
|
||||
}
|
||||
|
||||
if strings.HasPrefix(identifier.Value, "*.") {
|
||||
ba.Wildcard = true
|
||||
ba.Identifier = Identifier{
|
||||
Value: strings.TrimPrefix(identifier.Value, "*."),
|
||||
Type: identifier.Type,
|
||||
}
|
||||
}
|
||||
|
||||
return ba, nil
|
||||
}
|
||||
|
||||
// getID returns the ID of the authz.
|
||||
func (ba *baseAuthz) getID() string {
|
||||
return ba.ID
|
||||
}
|
||||
|
||||
// getAccountID returns the Account ID that created the authz.
|
||||
func (ba *baseAuthz) getAccountID() string {
|
||||
return ba.AccountID
|
||||
}
|
||||
|
||||
// getType returns the type of the authz.
|
||||
func (ba *baseAuthz) getType() string {
|
||||
return ba.Identifier.Type
|
||||
}
|
||||
|
||||
// getIdentifier returns the identifier for the authz.
|
||||
func (ba *baseAuthz) getIdentifier() Identifier {
|
||||
return ba.Identifier
|
||||
}
|
||||
|
||||
// getStatus returns the status of the authz.
|
||||
func (ba *baseAuthz) getStatus() string {
|
||||
return ba.Status
|
||||
}
|
||||
|
||||
// getWildcard returns true if the authz identifier has a '*', false otherwise.
|
||||
func (ba *baseAuthz) getWildcard() bool {
|
||||
return ba.Wildcard
|
||||
}
|
||||
|
||||
// getChallenges returns the authz challenge IDs.
|
||||
func (ba *baseAuthz) getChallenges() []string {
|
||||
return ba.Challenges
|
||||
}
|
||||
|
||||
// getExpiry returns the expiration time of the authz.
|
||||
func (ba *baseAuthz) getExpiry() time.Time {
|
||||
return ba.Expires
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the authz.
|
||||
func (ba *baseAuthz) getCreated() time.Time {
|
||||
return ba.Created
|
||||
}
|
||||
|
||||
// toACME converts the internal Authz type into the public acmeAuthz type for
|
||||
// presentation in the ACME protocol.
|
||||
func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) {
|
||||
var chs = make([]*Challenge, len(ba.Challenges))
|
||||
for i, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chs[i], err = ch.toACME(ctx, db, dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Authz{
|
||||
Identifier: ba.Identifier,
|
||||
Status: ba.getStatus(),
|
||||
Challenges: chs,
|
||||
Wildcard: ba.getWildcard(),
|
||||
Expires: ba.Expires.Format(time.RFC3339),
|
||||
ID: ba.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
|
||||
var (
|
||||
err error
|
||||
oldB, newB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
|
||||
}
|
||||
}
|
||||
if newB, err = json.Marshal(ba); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing authz; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) clone() *baseAuthz {
|
||||
u := *ba
|
||||
return &u
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) parent() authz {
|
||||
return &dnsAuthz{ba}
|
||||
}
|
||||
|
||||
// updateStatus attempts to update the status on a baseAuthz and stores the
|
||||
// updating object if necessary.
|
||||
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
|
||||
newAuthz := ba.clone()
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch ba.Status {
|
||||
case StatusInvalid:
|
||||
return ba.parent(), nil
|
||||
case StatusValid:
|
||||
return ba.parent(), nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(ba.Expires) {
|
||||
newAuthz.Status = StatusInvalid
|
||||
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var isValid = false
|
||||
for _, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return ba, err
|
||||
}
|
||||
if ch.getStatus() == StatusValid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return ba.parent(), nil
|
||||
}
|
||||
newAuthz.Status = StatusValid
|
||||
newAuthz.Error = nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
|
||||
}
|
||||
|
||||
if err := newAuthz.save(db, ba); err != nil {
|
||||
return ba, err
|
||||
}
|
||||
return newAuthz.parent(), nil
|
||||
}
|
||||
|
||||
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
|
||||
func unmarshalAuthz(data []byte) (authz, error) {
|
||||
var getType struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &getType); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
|
||||
}
|
||||
|
||||
switch getType.Identifier.Type {
|
||||
case "dns":
|
||||
var ba baseAuthz
|
||||
if err := json.Unmarshal(data, &ba); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
|
||||
}
|
||||
return &dnsAuthz{&ba}, nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
|
||||
getType.Identifier.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// dnsAuthz represents a dns acme authorization.
|
||||
type dnsAuthz struct {
|
||||
*baseAuthz
|
||||
}
|
||||
|
||||
// newAuthz returns a new acme authorization object based on the identifier
|
||||
// type.
|
||||
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
|
||||
switch identifier.Type {
|
||||
case "dns":
|
||||
a, err = newDNSAuthz(db, accID, identifier)
|
||||
default:
|
||||
err = MalformedErr(errors.Errorf("unexpected authz type %s",
|
||||
identifier.Type))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// newDNSAuthz returns a new dns acme authorization object.
|
||||
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
|
||||
ba, err := newBaseAuthz(accID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ba.Challenges = []string{}
|
||||
if !ba.Wildcard {
|
||||
// http and alpn challenges are only permitted if the DNS is not a wildcard dns.
|
||||
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: ba.Identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating http challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch1.getID())
|
||||
|
||||
ch2, err := newTLSALPN01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: ba.Identifier,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating alpn challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch2.getID())
|
||||
}
|
||||
ch3, err := newDNS01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating dns challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch3.getID())
|
||||
|
||||
da := &dnsAuthz{ba}
|
||||
if err := da.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
// getAuthz retrieves and unmarshals an ACME authz type from the database.
|
||||
func getAuthz(db nosql.DB, id string) (authz, error) {
|
||||
b, err := db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
|
||||
}
|
||||
az, err := unmarshalAuthz(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return az, nil
|
||||
}
|
@ -1,836 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func newAz() (authz, error) {
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
return newAuthz(mockdb, "1234", Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthz(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
az authz
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzClone(t *testing.T) {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
clone := az.clone()
|
||||
|
||||
assert.Equals(t, clone.getID(), az.getID())
|
||||
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, clone.getStatus(), az.getStatus())
|
||||
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, clone.getCreated(), az.getCreated())
|
||||
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
||||
|
||||
clone.Status = StatusValid
|
||||
|
||||
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
||||
}
|
||||
|
||||
func TestNewAuthz(t *testing.T) {
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
accID := "1234"
|
||||
type test struct {
|
||||
iden Identifier
|
||||
db nosql.DB
|
||||
err *Error
|
||||
resChs *([]string)
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
||||
err: MalformedErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"fail/new-http-chall-error": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/new-tls-alpn-chall-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/new-dns-chall-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/save-authz-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 3 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 3 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), false)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
"ok/wildcard": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
||||
return test{
|
||||
iden: _iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), true)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
// Verify that we only have 1 challenge instead of 2.
|
||||
assert.True(t, len(*chs) == 1)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
az, err := newAuthz(tc.db, accID, tc.iden)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getType(), "dns")
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
|
||||
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
||||
|
||||
if strings.HasPrefix(tc.iden.Value, "*.") {
|
||||
assert.True(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
||||
} else {
|
||||
assert.False(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
||||
}
|
||||
|
||||
assert.True(t, az.getID() != "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
|
||||
var (
|
||||
ch1, ch2 challenge
|
||||
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
ch1, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
prov := newProv()
|
||||
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
|
||||
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/getChallenge1-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"fail/getChallenge2-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 1 {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAz, err := az.toACME(ctx, tc.db, dir)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAz.ID, az.getID())
|
||||
assert.Equals(t, acmeAz.Identifier, iden)
|
||||
assert.Equals(t, acmeAz.Status, StatusPending)
|
||||
|
||||
acmeCh1, err := ch1.toACME(ctx, nil, dir)
|
||||
assert.FatalError(t, err)
|
||||
acmeCh2, err := ch2.toACME(ctx, nil, dir)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
||||
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
||||
|
||||
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzSave(t *testing.T) {
|
||||
type test struct {
|
||||
az, old authz
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAz, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAz)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: oldAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.az.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUnmarshal(t *testing.T) {
|
||||
type test struct {
|
||||
az authz
|
||||
azb []byte
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/nil": func(t *testing.T) test {
|
||||
return test{
|
||||
azb: nil,
|
||||
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
azb: b,
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok/dns": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
azb: b,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUpdateStatus(t *testing.T) {
|
||||
type test struct {
|
||||
az, res authz
|
||||
err *Error
|
||||
db nosql.DB
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/already-invalid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/already-valid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusValid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/unexpected-status": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusReady
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok/expired": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Error = MalformedErr(errors.New("authz has expired"))
|
||||
clone.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/get-challenge-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok/valid": func(t *testing.T) test {
|
||||
var (
|
||||
ch3 challenge
|
||||
ch2Bytes = &([]byte{})
|
||||
ch1Bytes = &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
} else if count == 2 {
|
||||
ch3, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Error = MalformedErr(nil)
|
||||
|
||||
_ch, ok := ch3.(*dns01Challenge)
|
||||
assert.Fatal(t, ok)
|
||||
_ch.baseChallenge.Status = StatusValid
|
||||
chb, err := json.Marshal(ch3)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Status = StatusValid
|
||||
clone.Error = nil
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
if count == 1 {
|
||||
count++
|
||||
return *ch2Bytes, nil
|
||||
}
|
||||
count++
|
||||
return chb, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/still-pending": func(t *testing.T) test {
|
||||
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
az, err := tc.az.updateStatus(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
expB, err := json.Marshal(tc.res)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expB, b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,253 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func defaultCertOps() (*CertOptions, error) {
|
||||
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CertOptions{
|
||||
AccountID: "accID",
|
||||
OrderID: "ordID",
|
||||
Leaf: crt,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newcert() (*certificate, error) {
|
||||
ops, err := defaultCertOps()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newCert(mockdb, *ops)
|
||||
}
|
||||
|
||||
func TestNewCert(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ops CertOptions
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := newCert(tc.db, tc.ops); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, *tc.id)
|
||||
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
|
||||
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: tc.ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range tc.ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, intermediates)
|
||||
|
||||
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCert(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
cert *certificate
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := getCert(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.ID, cert.ID)
|
||||
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
|
||||
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
|
||||
assert.Equals(t, tc.cert.Created, cert.Created)
|
||||
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
|
||||
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateToACME(t *testing.T) {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
acmeCert, err := cert.toACME(nil, nil)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,251 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is an error that should be used by the acme.DB interface to
|
||||
// indicate that an entity does not exist. For example, in the new-account
|
||||
// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new
|
||||
// account.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// DB is the DB interface expected by the step-ca ACME API.
|
||||
type DB interface {
|
||||
CreateAccount(ctx context.Context, acc *Account) error
|
||||
GetAccount(ctx context.Context, id string) (*Account, error)
|
||||
GetAccountByKeyID(ctx context.Context, kid string) (*Account, error)
|
||||
UpdateAccount(ctx context.Context, acc *Account) error
|
||||
|
||||
CreateNonce(ctx context.Context) (Nonce, error)
|
||||
DeleteNonce(ctx context.Context, nonce Nonce) error
|
||||
|
||||
CreateAuthorization(ctx context.Context, az *Authorization) error
|
||||
GetAuthorization(ctx context.Context, id string) (*Authorization, error)
|
||||
UpdateAuthorization(ctx context.Context, az *Authorization) error
|
||||
|
||||
CreateCertificate(ctx context.Context, cert *Certificate) error
|
||||
GetCertificate(ctx context.Context, id string) (*Certificate, error)
|
||||
|
||||
CreateChallenge(ctx context.Context, ch *Challenge) error
|
||||
GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||
UpdateChallenge(ctx context.Context, ch *Challenge) error
|
||||
|
||||
CreateOrder(ctx context.Context, o *Order) error
|
||||
GetOrder(ctx context.Context, id string) (*Order, error)
|
||||
GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
|
||||
UpdateOrder(ctx context.Context, o *Order) error
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
MockCreateAccount func(ctx context.Context, acc *Account) error
|
||||
MockGetAccount func(ctx context.Context, id string) (*Account, error)
|
||||
MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error)
|
||||
MockUpdateAccount func(ctx context.Context, acc *Account) error
|
||||
|
||||
MockCreateNonce func(ctx context.Context) (Nonce, error)
|
||||
MockDeleteNonce func(ctx context.Context, nonce Nonce) error
|
||||
|
||||
MockCreateAuthorization func(ctx context.Context, az *Authorization) error
|
||||
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
|
||||
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
|
||||
|
||||
MockCreateCertificate func(ctx context.Context, cert *Certificate) error
|
||||
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
|
||||
|
||||
MockCreateChallenge func(ctx context.Context, ch *Challenge) error
|
||||
MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||
MockUpdateChallenge func(ctx context.Context, ch *Challenge) error
|
||||
|
||||
MockCreateOrder func(ctx context.Context, o *Order) error
|
||||
MockGetOrder func(ctx context.Context, id string) (*Order, error)
|
||||
MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error)
|
||||
MockUpdateOrder func(ctx context.Context, o *Order) error
|
||||
|
||||
MockRet1 interface{}
|
||||
MockError error
|
||||
}
|
||||
|
||||
// CreateAccount mock.
|
||||
func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error {
|
||||
if m.MockCreateAccount != nil {
|
||||
return m.MockCreateAccount(ctx, acc)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAccount mock.
|
||||
func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) {
|
||||
if m.MockGetAccount != nil {
|
||||
return m.MockGetAccount(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Account), m.MockError
|
||||
}
|
||||
|
||||
// GetAccountByKeyID mock
|
||||
func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) {
|
||||
if m.MockGetAccountByKeyID != nil {
|
||||
return m.MockGetAccountByKeyID(ctx, kid)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Account), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAccount mock
|
||||
func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error {
|
||||
if m.MockUpdateAccount != nil {
|
||||
return m.MockUpdateAccount(ctx, acc)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateNonce mock
|
||||
func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) {
|
||||
if m.MockCreateNonce != nil {
|
||||
return m.MockCreateNonce(ctx)
|
||||
} else if m.MockError != nil {
|
||||
return Nonce(""), m.MockError
|
||||
}
|
||||
return m.MockRet1.(Nonce), m.MockError
|
||||
}
|
||||
|
||||
// DeleteNonce mock
|
||||
func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error {
|
||||
if m.MockDeleteNonce != nil {
|
||||
return m.MockDeleteNonce(ctx, nonce)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateAuthorization mock
|
||||
func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error {
|
||||
if m.MockCreateAuthorization != nil {
|
||||
return m.MockCreateAuthorization(ctx, az)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAuthorization mock
|
||||
func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) {
|
||||
if m.MockGetAuthorization != nil {
|
||||
return m.MockGetAuthorization(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Authorization), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAuthorization mock
|
||||
func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error {
|
||||
if m.MockUpdateAuthorization != nil {
|
||||
return m.MockUpdateAuthorization(ctx, az)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateCertificate mock
|
||||
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
|
||||
if m.MockCreateCertificate != nil {
|
||||
return m.MockCreateCertificate(ctx, cert)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetCertificate mock
|
||||
func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) {
|
||||
if m.MockGetCertificate != nil {
|
||||
return m.MockGetCertificate(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Certificate), m.MockError
|
||||
}
|
||||
|
||||
// CreateChallenge mock
|
||||
func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
|
||||
if m.MockCreateChallenge != nil {
|
||||
return m.MockCreateChallenge(ctx, ch)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetChallenge mock
|
||||
func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) {
|
||||
if m.MockGetChallenge != nil {
|
||||
return m.MockGetChallenge(ctx, chID, azID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Challenge), m.MockError
|
||||
}
|
||||
|
||||
// UpdateChallenge mock
|
||||
func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error {
|
||||
if m.MockUpdateChallenge != nil {
|
||||
return m.MockUpdateChallenge(ctx, ch)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateOrder mock
|
||||
func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error {
|
||||
if m.MockCreateOrder != nil {
|
||||
return m.MockCreateOrder(ctx, o)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetOrder mock
|
||||
func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) {
|
||||
if m.MockGetOrder != nil {
|
||||
return m.MockGetOrder(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Order), m.MockError
|
||||
}
|
||||
|
||||
// UpdateOrder mock
|
||||
func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error {
|
||||
if m.MockUpdateOrder != nil {
|
||||
return m.MockUpdateOrder(ctx, o)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetOrdersByAccountID mock
|
||||
func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
|
||||
if m.MockGetOrdersByAccountID != nil {
|
||||
return m.MockGetOrdersByAccountID(ctx, accID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.([]string), m.MockError
|
||||
}
|
@ -0,0 +1,136 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// dbAccount represents an ACME account.
|
||||
type dbAccount struct {
|
||||
ID string `json:"id"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status acme.Status `json:"status"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeactivatedAt time.Time `json:"deactivatedAt"`
|
||||
}
|
||||
|
||||
func (dba *dbAccount) clone() *dbAccount {
|
||||
nu := *dba
|
||||
return &nu
|
||||
}
|
||||
|
||||
func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
|
||||
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return "", acme.ErrNotFound
|
||||
}
|
||||
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
|
||||
}
|
||||
return string(id), nil
|
||||
}
|
||||
|
||||
// getDBAccount retrieves and unmarshals dbAccount.
|
||||
func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
|
||||
data, err := db.db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return nil, acme.ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrapf(err, "error loading account %s", id)
|
||||
}
|
||||
|
||||
dbacc := new(dbAccount)
|
||||
if err = json.Unmarshal(data, dbacc); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id)
|
||||
}
|
||||
return dbacc, nil
|
||||
}
|
||||
|
||||
// GetAccount retrieves an ACME account by ID.
|
||||
func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
|
||||
dbacc, err := db.getDBAccount(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &acme.Account{
|
||||
Status: dbacc.Status,
|
||||
Contact: dbacc.Contact,
|
||||
Key: dbacc.Key,
|
||||
ID: dbacc.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
|
||||
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
|
||||
id, err := db.getAccountIDByKeyID(ctx, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.GetAccount(ctx, id)
|
||||
}
|
||||
|
||||
// CreateAccount imlements the AcmeDB.CreateAccount interface.
|
||||
func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
||||
var err error
|
||||
acc.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dba := &dbAccount{
|
||||
ID: acc.ID,
|
||||
Key: acc.Key,
|
||||
Contact: acc.Contact,
|
||||
Status: acc.Status,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
kid, err := acme.KeyToID(dba.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kidB := []byte(kid)
|
||||
|
||||
// Set the jwkID -> acme account ID index
|
||||
_, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID))
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrap(err, "error storing keyID to accountID index")
|
||||
case !swapped:
|
||||
return errors.Errorf("key-id to account-id index already exists")
|
||||
default:
|
||||
if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil {
|
||||
db.db.Del(accountByKeyIDTable, kidB)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
|
||||
func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
|
||||
old, err := db.getDBAccount(ctx, acc.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.Contact = acc.Contact
|
||||
nu.Status = acc.Status
|
||||
|
||||
// If the status has changed to 'deactivated', then set deactivatedAt timestamp.
|
||||
if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated {
|
||||
nu.DeactivatedAt = clock.Now()
|
||||
}
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "account", accountTable)
|
||||
}
|
@ -0,0 +1,706 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
func TestDB_getDBAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbacc *dbAccount
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: acme.ErrNotFound,
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling account accID into dbAccount"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbacc: dbacc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||
accID := "accID"
|
||||
kid := "kid"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: acme.ErrNotFound,
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading key-account index for key kid: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return []byte(accID), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, retAccID, accID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbacc *dbAccount
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbacc: dbacc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if acc, err := db.GetAccount(context.Background(), accID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||
accID := "accID"
|
||||
kid := "kid"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbacc *dbAccount
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.getAccountIDByKeyID-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, string(bucket), string(accountByKeyIDTable))
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading key-account index for key kid: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetAccount-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), kid)
|
||||
return []byte(accID), nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), accID)
|
||||
return nil, errors.New("force")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), kid)
|
||||
return []byte(accID), nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), accID)
|
||||
return b, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
dbacc: dbacc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateAccount(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
acc *acme.Account
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/keyID-cmpAndSwap-error": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
assert.Equals(t, nu, []byte(acc.ID))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
err: errors.New("error storing keyID to accountID index: force"),
|
||||
}
|
||||
},
|
||||
"fail/keyID-cmpAndSwap-false": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
assert.Equals(t, nu, []byte(acc.ID))
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
err: errors.New("key-id to account-id index already exists"),
|
||||
}
|
||||
},
|
||||
"fail/account-save-error": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
return nu, true, nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), acc.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbacc := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||
assert.Equals(t, dbacc.ID, string(key))
|
||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||
return nil, false, errors.New("force")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
err: errors.New("error saving acme account: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
id = string(key)
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
return nu, true, nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), acc.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbacc := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||
assert.Equals(t, dbacc.ID, string(key))
|
||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||
return nu, true, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_UpdateAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
acc *acme.Account
|
||||
err error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
acc: &acme.Account{
|
||||
ID: accID,
|
||||
},
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"fail/already-deactivated": func(t *testing.T) test {
|
||||
clone := dbacc.clone()
|
||||
clone.Status = acme.StatusDeactivated
|
||||
clone.DeactivatedAt = now
|
||||
dbaccb, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return dbaccb, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbNew := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, clone.ID)
|
||||
assert.Equals(t, dbNew.Status, clone.Status)
|
||||
assert.Equals(t, dbNew.Contact, clone.Contact)
|
||||
assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt)
|
||||
assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme account: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbNew := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||
assert.Equals(t, dbNew.Status, acc.Status)
|
||||
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme account: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbNew := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||
assert.Equals(t, dbNew.Status, acc.Status)
|
||||
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, dbacc.ID)
|
||||
assert.Equals(t, tc.acc.Status, dbacc.Status)
|
||||
assert.Equals(t, tc.acc.Contact, dbacc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,118 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// dbAuthz is the base authz type that others build from.
|
||||
type dbAuthz struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Identifier acme.Identifier `json:"identifier"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
ChallengeIDs []string `json:"challengeIDs"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
Error *acme.Error `json:"error"`
|
||||
}
|
||||
|
||||
func (ba *dbAuthz) clone() *dbAuthz {
|
||||
u := *ba
|
||||
return &u
|
||||
}
|
||||
|
||||
// getDBAuthz retrieves and unmarshals a database representation of the
|
||||
// ACME Authorization type.
|
||||
func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) {
|
||||
data, err := db.db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading authz %s", id)
|
||||
}
|
||||
|
||||
var dbaz dbAuthz
|
||||
if err = json.Unmarshal(data, &dbaz); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id)
|
||||
}
|
||||
return &dbaz, nil
|
||||
}
|
||||
|
||||
// GetAuthorization retrieves and unmarshals an ACME authz type from the database.
|
||||
// Implements acme.DB GetAuthorization interface.
|
||||
func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||
dbaz, err := db.getDBAuthz(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs))
|
||||
for i, chID := range dbaz.ChallengeIDs {
|
||||
chs[i], err = db.GetChallenge(ctx, chID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &acme.Authorization{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: chs,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Error: dbaz.Error,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateAuthorization creates an entry in the database for the Authorization.
|
||||
// Implements the acme.DB.CreateAuthorization interface.
|
||||
func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
var err error
|
||||
az.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chIDs := make([]string, len(az.Challenges))
|
||||
for i, ch := range az.Challenges {
|
||||
chIDs[i] = ch.ID
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: az.ID,
|
||||
AccountID: az.AccountID,
|
||||
Status: az.Status,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: az.ExpiresAt,
|
||||
Identifier: az.Identifier,
|
||||
ChallengeIDs: chIDs,
|
||||
Token: az.Token,
|
||||
Wildcard: az.Wildcard,
|
||||
}
|
||||
|
||||
return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable)
|
||||
}
|
||||
|
||||
// UpdateAuthorization saves an updated ACME Authorization to the database.
|
||||
func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
old, err := db.getDBAuthz(ctx, az.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
nu.Status = az.Status
|
||||
nu.Error = az.Error
|
||||
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
||||
}
|
@ -0,0 +1,620 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_getDBAuthz(t *testing.T) {
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbaz *dbAuthz
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading authz azID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling authz azID into dbAuthz"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbaz: dbaz,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
||||
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAuthorization(t *testing.T) {
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbaz *dbAuthz
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading authz azID: force"),
|
||||
}
|
||||
},
|
||||
"fail/forward-acme-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(authzTable):
|
||||
assert.Equals(t, string(key), azID)
|
||||
return b, nil
|
||||
case string(challengeTable):
|
||||
assert.Equals(t, string(key), "foo")
|
||||
return nil, errors.New("force")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge foo: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetChallenge-not-found": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(authzTable):
|
||||
assert.Equals(t, string(key), azID)
|
||||
return b, nil
|
||||
case string(challengeTable):
|
||||
assert.Equals(t, string(key), "foo")
|
||||
return nil, nosqldb.ErrNotFound
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
chCount := 0
|
||||
fooChb, err := json.Marshal(&dbChallenge{ID: "foo"})
|
||||
assert.FatalError(t, err)
|
||||
barChb, err := json.Marshal(&dbChallenge{ID: "bar"})
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(authzTable):
|
||||
assert.Equals(t, string(key), azID)
|
||||
return b, nil
|
||||
case string(challengeTable):
|
||||
if chCount == 0 {
|
||||
chCount++
|
||||
assert.Equals(t, string(key), "foo")
|
||||
return fooChb, nil
|
||||
}
|
||||
assert.Equals(t, string(key), "bar")
|
||||
return barChb, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
dbaz: dbaz,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if az, err := db.GetAuthorization(context.Background(), azID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, az.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, az.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
||||
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
})
|
||||
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateAuthorization(t *testing.T) {
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
az *acme.Authorization
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &acme.Authorization{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Wildcard: true,
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), az.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbaz := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbaz))
|
||||
assert.Equals(t, dbaz.ID, string(key))
|
||||
assert.Equals(t, dbaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
})
|
||||
assert.Equals(t, dbaz.Status, az.Status)
|
||||
assert.Equals(t, dbaz.Token, az.Token)
|
||||
assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
|
||||
assert.Equals(t, dbaz.Wildcard, az.Wildcard)
|
||||
assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
|
||||
assert.Nil(t, dbaz.Error)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
az: az,
|
||||
err: errors.New("error saving acme authz: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
now = clock.Now()
|
||||
az = &acme.Authorization{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Wildcard: true,
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), az.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbaz := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbaz))
|
||||
assert.Equals(t, dbaz.ID, string(key))
|
||||
assert.Equals(t, dbaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
})
|
||||
assert.Equals(t, dbaz.Status, az.Status)
|
||||
assert.Equals(t, dbaz.Token, az.Token)
|
||||
assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
|
||||
assert.Equals(t, dbaz.Wildcard, az.Wildcard)
|
||||
assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
|
||||
assert.Nil(t, dbaz.Error)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
az: az,
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_UpdateAuthorization(t *testing.T) {
|
||||
azID := "azID"
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
az *acme.Authorization
|
||||
err error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
az: &acme.Authorization{
|
||||
ID: azID,
|
||||
},
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading authz azID: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||
updAz := &acme.Authorization{
|
||||
ID: azID,
|
||||
Status: acme.StatusValid,
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
az: updAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbaz, dbOld)
|
||||
|
||||
dbNew := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbaz.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
|
||||
assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.Token, dbaz.Token)
|
||||
assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
|
||||
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme authz: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
updAz := &acme.Authorization{
|
||||
ID: azID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Status: acme.StatusValid,
|
||||
Identifier: dbaz.Identifier,
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Token: dbaz.Token,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
az: updAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbaz, dbOld)
|
||||
|
||||
dbNew := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbaz.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
|
||||
assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.Token, dbaz.Token)
|
||||
assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
|
||||
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.ID, dbaz.ID)
|
||||
assert.Equals(t, tc.az.AccountID, dbaz.AccountID)
|
||||
assert.Equals(t, tc.az.Identifier, dbaz.Identifier)
|
||||
assert.Equals(t, tc.az.Status, acme.StatusValid)
|
||||
assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, tc.az.Token, dbaz.Token)
|
||||
assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, tc.az.Challenges, []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
})
|
||||
assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,109 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type dbCert struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
AccountID string `json:"accountID"`
|
||||
OrderID string `json:"orderID"`
|
||||
Leaf []byte `json:"leaf"`
|
||||
Intermediates []byte `json:"intermediates"`
|
||||
}
|
||||
|
||||
// CreateCertificate creates and stores an ACME certificate type.
|
||||
func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
|
||||
var err error
|
||||
cert.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range cert.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
dbch := &dbCert{
|
||||
ID: cert.ID,
|
||||
AccountID: cert.AccountID,
|
||||
OrderID: cert.OrderID,
|
||||
Leaf: leaf,
|
||||
Intermediates: intermediates,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
|
||||
}
|
||||
|
||||
// GetCertificate retrieves and unmarshals an ACME certificate type from the
|
||||
// datastore.
|
||||
func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
|
||||
b, err := db.db.Get(certTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading certificate %s", id)
|
||||
}
|
||||
dbC := new(dbCert)
|
||||
if err := json.Unmarshal(b, dbC); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id)
|
||||
}
|
||||
|
||||
certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id)
|
||||
}
|
||||
|
||||
return &acme.Certificate{
|
||||
ID: dbC.ID,
|
||||
AccountID: dbC.AccountID,
|
||||
OrderID: dbC.OrderID,
|
||||
Leaf: certs[0],
|
||||
Intermediates: certs[1:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseBundle(b []byte) ([]*x509.Certificate, error) {
|
||||
var (
|
||||
err error
|
||||
block *pem.Block
|
||||
bundle []*x509.Certificate
|
||||
)
|
||||
for len(b) > 0 {
|
||||
block, b = pem.Decode(b)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, errors.New("error decoding PEM: data contains block that is not a certificate")
|
||||
}
|
||||
var crt *x509.Certificate
|
||||
crt, err = x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing x509 certificate")
|
||||
}
|
||||
bundle = append(bundle, crt)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
return nil, errors.New("error decoding PEM: unexpected data")
|
||||
}
|
||||
return bundle, nil
|
||||
|
||||
}
|
@ -0,0 +1,321 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestDB_CreateCertificate(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
cert *acme.Certificate
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
cert := &acme.Certificate{
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: leaf,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbCert)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.ID, cert.ID)
|
||||
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
cert: cert,
|
||||
err: errors.New("error saving acme certificate: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert := &acme.Certificate{
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: leaf,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbCert)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.ID, cert.ID)
|
||||
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
_id: idPtr,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetCertificate(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
certID := "certID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
return nil, errors.Errorf("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading certificate certID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
return []byte("foobar"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling certificate certID"),
|
||||
}
|
||||
},
|
||||
"fail/parseBundle-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
cert := dbCert{
|
||||
ID: certID,
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "Public Key",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
cert := dbCert{
|
||||
ID: certID,
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
Intermediates: append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: inter.Raw,
|
||||
}), pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})...),
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
cert, err := db.GetCertificate(context.Background(), certID)
|
||||
if err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, certID)
|
||||
assert.Equals(t, cert.AccountID, "accountID")
|
||||
assert.Equals(t, cert.OrderID, "orderID")
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseBundle(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
var certs []byte
|
||||
for _, cert := range []*x509.Certificate{leaf, inter, root} {
|
||||
certs = append(certs, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
type test struct {
|
||||
b []byte
|
||||
err error
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"fail/bad-type-error": {
|
||||
b: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "Public Key",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"),
|
||||
},
|
||||
"fail/bad-pem-error": {
|
||||
b: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: []byte("foo"),
|
||||
}),
|
||||
err: errors.Errorf("error parsing x509 certificate"),
|
||||
},
|
||||
"fail/unexpected-data": {
|
||||
b: append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}), []byte("foo")...),
|
||||
err: errors.Errorf("error decoding PEM: unexpected data"),
|
||||
},
|
||||
"ok": {
|
||||
b: certs,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ret, err := parseBundle(tc.b)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,103 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type dbChallenge struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Type string `json:"type"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Value string `json:"value"`
|
||||
ValidatedAt string `json:"validatedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Error *acme.Error `json:"error"`
|
||||
}
|
||||
|
||||
func (dbc *dbChallenge) clone() *dbChallenge {
|
||||
u := *dbc
|
||||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
|
||||
data, err := db.db.Get(challengeTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading acme challenge %s", id)
|
||||
}
|
||||
|
||||
dbch := new(dbChallenge)
|
||||
if err := json.Unmarshal(data, dbch); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling dbChallenge")
|
||||
}
|
||||
return dbch, nil
|
||||
}
|
||||
|
||||
// CreateChallenge creates a new ACME challenge data structure in the database.
|
||||
// Implements acme.DB.CreateChallenge interface.
|
||||
func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
||||
var err error
|
||||
ch.ID, err = randID()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error generating random id for ACME challenge")
|
||||
}
|
||||
|
||||
dbch := &dbChallenge{
|
||||
ID: ch.ID,
|
||||
AccountID: ch.AccountID,
|
||||
Value: ch.Value,
|
||||
Status: acme.StatusPending,
|
||||
Token: ch.Token,
|
||||
CreatedAt: clock.Now(),
|
||||
Type: ch.Type,
|
||||
}
|
||||
|
||||
return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable)
|
||||
}
|
||||
|
||||
// GetChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||
// Implements the acme.DB GetChallenge interface.
|
||||
func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
|
||||
dbch, err := db.getDBChallenge(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := &acme.Challenge{
|
||||
ID: dbch.ID,
|
||||
AccountID: dbch.AccountID,
|
||||
Type: dbch.Type,
|
||||
Value: dbch.Value,
|
||||
Status: dbch.Status,
|
||||
Token: dbch.Token,
|
||||
Error: dbch.Error,
|
||||
ValidatedAt: dbch.ValidatedAt,
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// UpdateChallenge updates an ACME challenge type in the database.
|
||||
func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
||||
old, err := db.getDBChallenge(ctx, ch.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
// These should be the only values changing in an Update request.
|
||||
nu.Status = ch.Status
|
||||
nu.Error = ch.Error
|
||||
nu.ValidatedAt = ch.ValidatedAt
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
|
||||
}
|
@ -0,0 +1,464 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_getDBChallenge(t *testing.T) {
|
||||
chID := "chID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbc *dbChallenge
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge chID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling dbChallenge"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dbc := &dbChallenge{
|
||||
ID: chID,
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbc: dbc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateChallenge(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ch *acme.Challenge
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
ch := &acme.Challenge{
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), ch.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
||||
assert.Equals(t, dbc.Type, ch.Type)
|
||||
assert.Equals(t, dbc.Status, ch.Status)
|
||||
assert.Equals(t, dbc.Token, ch.Token)
|
||||
assert.Equals(t, dbc.Value, ch.Value)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
ch: ch,
|
||||
err: errors.New("error saving acme challenge: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
ch = &acme.Challenge{
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
}
|
||||
)
|
||||
|
||||
return test{
|
||||
ch: ch,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), ch.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
||||
assert.Equals(t, dbc.Type, ch.Type)
|
||||
assert.Equals(t, dbc.Status, ch.Status)
|
||||
assert.Equals(t, dbc.Token, ch.Token)
|
||||
assert.Equals(t, dbc.Value, ch.Value)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.ch.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetChallenge(t *testing.T) {
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbc *dbChallenge
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge chID: force"),
|
||||
}
|
||||
},
|
||||
"fail/forward-acme-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dbc := &dbChallenge{
|
||||
ID: chID,
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbc: dbc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_UpdateChallenge(t *testing.T) {
|
||||
chID := "chID"
|
||||
dbc := &dbChallenge{
|
||||
ID: chID,
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ch *acme.Challenge
|
||||
err error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
ch: &acme.Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge chID: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||
updCh := &acme.Challenge{
|
||||
ID: chID,
|
||||
Status: acme.StatusValid,
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
ch: updCh,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbc, dbOld)
|
||||
|
||||
dbNew := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbc.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
||||
assert.Equals(t, dbNew.Type, dbc.Type)
|
||||
assert.Equals(t, dbNew.Status, updCh.Status)
|
||||
assert.Equals(t, dbNew.Token, dbc.Token)
|
||||
assert.Equals(t, dbNew.Value, dbc.Value)
|
||||
assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
|
||||
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||
assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme challenge: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
updCh := &acme.Challenge{
|
||||
ID: dbc.ID,
|
||||
AccountID: dbc.AccountID,
|
||||
Type: dbc.Type,
|
||||
Token: dbc.Token,
|
||||
Value: dbc.Value,
|
||||
Status: acme.StatusValid,
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
ch: updCh,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbc, dbOld)
|
||||
|
||||
dbNew := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbc.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
||||
assert.Equals(t, dbNew.Type, dbc.Type)
|
||||
assert.Equals(t, dbNew.Token, dbc.Token)
|
||||
assert.Equals(t, dbNew.Value, dbc.Value)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.ValidatedAt, "foobar")
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.ch.ID, dbc.ID)
|
||||
assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
|
||||
assert.Equals(t, tc.ch.Type, dbc.Type)
|
||||
assert.Equals(t, tc.ch.Token, dbc.Token)
|
||||
assert.Equals(t, tc.ch.Value, dbc.Value)
|
||||
assert.Equals(t, tc.ch.ValidatedAt, "foobar")
|
||||
assert.Equals(t, tc.ch.Status, acme.StatusValid)
|
||||
assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// dbNonce contains nonce metadata used in the ACME protocol.
|
||||
type dbNonce struct {
|
||||
ID string
|
||||
CreatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
}
|
||||
|
||||
// CreateNonce creates, stores, and returns an ACME replay-nonce.
|
||||
// Implements the acme.DB interface.
|
||||
func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
||||
_id, err := randID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||
n := &dbNonce{
|
||||
ID: id,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return acme.Nonce(id), nil
|
||||
}
|
||||
|
||||
// DeleteNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
|
||||
err := db.db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Get,
|
||||
},
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Delete,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce))
|
||||
case err != nil:
|
||||
return errors.Wrapf(err, "error deleting nonce %s", string(nonce))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
@ -0,0 +1,168 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_CreateNonce(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbn := new(dbNonce)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbn))
|
||||
assert.Equals(t, dbn.ID, string(key))
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme nonce: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbn := new(dbNonce)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbn))
|
||||
assert.Equals(t, dbn.ID, string(key))
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if n, err := db.CreateNonce(context.Background()); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, string(n), *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_DeleteNonce(t *testing.T) {
|
||||
|
||||
nonceID := "nonceID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return database.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID),
|
||||
}
|
||||
},
|
||||
"fail/db.Update-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error deleting nonce nonceID: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/randutil"
|
||||
)
|
||||
|
||||
var (
|
||||
accountTable = []byte("acme_accounts")
|
||||
accountByKeyIDTable = []byte("acme_keyID_accountID_index")
|
||||
authzTable = []byte("acme_authzs")
|
||||
challengeTable = []byte("acme_challenges")
|
||||
nonceTable = []byte("nonces")
|
||||
orderTable = []byte("acme_orders")
|
||||
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
||||
certTable = []byte("acme_certs")
|
||||
)
|
||||
|
||||
// DB is a struct that implements the AcmeDB interface.
|
||||
type DB struct {
|
||||
db nosqlDB.DB
|
||||
}
|
||||
|
||||
// New configures and returns a new ACME DB backend implemented using a nosql DB.
|
||||
func New(db nosqlDB.DB) (*DB, error) {
|
||||
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
||||
challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
string(b))
|
||||
}
|
||||
}
|
||||
return &DB{db}, nil
|
||||
}
|
||||
|
||||
// save writes the new data to the database, overwriting the old data if it
|
||||
// existed.
|
||||
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
||||
var (
|
||||
err error
|
||||
newB []byte
|
||||
)
|
||||
if nu == nil {
|
||||
newB = nil
|
||||
} else {
|
||||
newB, err = json.Marshal(nu)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu)
|
||||
}
|
||||
}
|
||||
var oldB []byte
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
oldB, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
|
||||
}
|
||||
}
|
||||
|
||||
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrapf(err, "error saving acme %s", typ)
|
||||
case !swapped:
|
||||
return errors.Errorf("error saving acme %s; changed since last read", typ)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var idLen = 32
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.Alphanumeric(idLen)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Second)
|
||||
}
|
||||
|
||||
var clock = new(Clock)
|
@ -0,0 +1,139 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"fail/db.CreateTable-error": {
|
||||
db: &db.MockNoSQLDB{
|
||||
MCreateTable: func(bucket []byte) error {
|
||||
assert.Equals(t, string(bucket), string(accountTable))
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.Errorf("error creating table %s: force", string(accountTable)),
|
||||
},
|
||||
"ok": {
|
||||
db: &db.MockNoSQLDB{
|
||||
MCreateTable: func(bucket []byte) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if _, err := New(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errorThrower string
|
||||
|
||||
func (et errorThrower) MarshalJSON() ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
|
||||
func TestDB_save(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
nu interface{}
|
||||
old interface{}
|
||||
err error
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"fail/error-marshaling-new": {
|
||||
nu: errorThrower("foo"),
|
||||
err: errors.New("error marshaling acme type: challenge"),
|
||||
},
|
||||
"fail/error-marshaling-old": {
|
||||
nu: "new",
|
||||
old: errorThrower("foo"),
|
||||
err: errors.New("error marshaling acme type: challenge"),
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": {
|
||||
nu: "new",
|
||||
old: "old",
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, string(old), "\"old\"")
|
||||
assert.Equals(t, string(nu), "\"new\"")
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme challenge: force"),
|
||||
},
|
||||
"fail/db.CmpAndSwap-false-marshaling-old": {
|
||||
nu: "new",
|
||||
old: "old",
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, string(old), "\"old\"")
|
||||
assert.Equals(t, string(nu), "\"new\"")
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme challenge; changed since last read"),
|
||||
},
|
||||
"ok": {
|
||||
nu: "new",
|
||||
old: "old",
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, string(old), "\"old\"")
|
||||
assert.Equals(t, string(nu), "\"new\"")
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
"ok/nils": {
|
||||
nu: nil,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, nu, nil)
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := &DB{db: tc.db}
|
||||
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,189 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Mutex for locking ordersByAccount index operations.
|
||||
var ordersByAccountMux sync.Mutex
|
||||
|
||||
type dbOrder struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Identifiers []acme.Identifier `json:"identifiers"`
|
||||
AuthorizationIDs []string `json:"authorizationIDs"`
|
||||
Status acme.Status `json:"status"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt,omitempty"`
|
||||
CertificateID string `json:"certificate,omitempty"`
|
||||
Error *acme.Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (a *dbOrder) clone() *dbOrder {
|
||||
b := *a
|
||||
return &b
|
||||
}
|
||||
|
||||
// getDBOrder retrieves and unmarshals an ACME Order type from the database.
|
||||
func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
|
||||
b, err := db.db.Get(orderTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading order %s", id)
|
||||
}
|
||||
o := new(dbOrder)
|
||||
if err := json.Unmarshal(b, &o); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// GetOrder retrieves an ACME Order from the database.
|
||||
func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
|
||||
dbo, err := db.getDBOrder(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o := &acme.Order{
|
||||
ID: dbo.ID,
|
||||
AccountID: dbo.AccountID,
|
||||
ProvisionerID: dbo.ProvisionerID,
|
||||
CertificateID: dbo.CertificateID,
|
||||
Status: dbo.Status,
|
||||
ExpiresAt: dbo.ExpiresAt,
|
||||
Identifiers: dbo.Identifiers,
|
||||
NotBefore: dbo.NotBefore,
|
||||
NotAfter: dbo.NotAfter,
|
||||
AuthorizationIDs: dbo.AuthorizationIDs,
|
||||
Error: dbo.Error,
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// CreateOrder creates ACME Order resources and saves them to the DB.
|
||||
func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
|
||||
var err error
|
||||
o.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
dbo := &dbOrder{
|
||||
ID: o.ID,
|
||||
AccountID: o.AccountID,
|
||||
ProvisionerID: o.ProvisionerID,
|
||||
Status: o.Status,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: o.ExpiresAt,
|
||||
Identifiers: o.Identifiers,
|
||||
NotBefore: o.NotBefore,
|
||||
NotAfter: o.NotAfter,
|
||||
AuthorizationIDs: o.AuthorizationIDs,
|
||||
}
|
||||
if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOrder saves an updated ACME Order to the database.
|
||||
func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
|
||||
old, err := db.getDBOrder(ctx, o.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
nu.Status = o.Status
|
||||
nu.Error = o.Error
|
||||
nu.CertificateID = o.CertificateID
|
||||
return db.save(ctx, old.ID, nu, old, "order", orderTable)
|
||||
}
|
||||
|
||||
func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) {
|
||||
ordersByAccountMux.Lock()
|
||||
defer ordersByAccountMux.Unlock()
|
||||
|
||||
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
|
||||
var (
|
||||
oldOids []string
|
||||
)
|
||||
if err != nil {
|
||||
if !nosql.IsErrNotFound(err) {
|
||||
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(b, &oldOids); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove any order that is not in PENDING state and update the stored list
|
||||
// before returning.
|
||||
//
|
||||
// According to RFC 8555:
|
||||
// The server SHOULD include pending orders and SHOULD NOT include orders
|
||||
// that are invalid in the array of URLs.
|
||||
pendOids := []string{}
|
||||
for _, oid := range oldOids {
|
||||
o, err := db.GetOrder(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID)
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID)
|
||||
}
|
||||
if o.Status == acme.StatusPending {
|
||||
pendOids = append(pendOids, oid)
|
||||
}
|
||||
}
|
||||
pendOids = append(pendOids, addOids...)
|
||||
var (
|
||||
_old interface{} = oldOids
|
||||
_new interface{} = pendOids
|
||||
)
|
||||
switch {
|
||||
case len(oldOids) == 0 && len(pendOids) == 0:
|
||||
// If list has not changed from empty, then no need to write the DB.
|
||||
return []string{}, nil
|
||||
case len(oldOids) == 0:
|
||||
_old = nil
|
||||
case len(pendOids) == 0:
|
||||
_new = nil
|
||||
}
|
||||
if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil {
|
||||
// Delete all orders that may have been previously stored if orderIDsByAccountID update fails.
|
||||
for _, oid := range addOids {
|
||||
// Ignore error from delete -- we tried our best.
|
||||
// TODO when we have logging w/ request ID tracking, logging this error.
|
||||
db.db.Del(orderTable, []byte(oid))
|
||||
}
|
||||
return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID)
|
||||
}
|
||||
return pendOids, nil
|
||||
}
|
||||
|
||||
// GetOrdersByAccountID returns a list of order IDs owned by the account.
|
||||
func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
|
||||
return db.updateAddOrderIDs(ctx, accID)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,150 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Directory represents an ACME directory for configuring clients.
|
||||
type Directory struct {
|
||||
NewNonce string `json:"newNonce,omitempty"`
|
||||
NewAccount string `json:"newAccount,omitempty"`
|
||||
NewOrder string `json:"newOrder,omitempty"`
|
||||
NewAuthz string `json:"newAuthz,omitempty"`
|
||||
RevokeCert string `json:"revokeCert,omitempty"`
|
||||
KeyChange string `json:"keyChange,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging for the Directory type.
|
||||
func (d *Directory) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
type directory struct {
|
||||
prefix, dns string
|
||||
}
|
||||
|
||||
// newDirectory returns a new Directory type.
|
||||
func newDirectory(dns, prefix string) *directory {
|
||||
return &directory{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Link captures the link type.
|
||||
type Link int
|
||||
|
||||
const (
|
||||
// NewNonceLink new-nonce
|
||||
NewNonceLink Link = iota
|
||||
// NewAccountLink new-account
|
||||
NewAccountLink
|
||||
// AccountLink account
|
||||
AccountLink
|
||||
// OrderLink order
|
||||
OrderLink
|
||||
// NewOrderLink new-order
|
||||
NewOrderLink
|
||||
// OrdersByAccountLink list of orders owned by account
|
||||
OrdersByAccountLink
|
||||
// FinalizeLink finalize order
|
||||
FinalizeLink
|
||||
// NewAuthzLink authz
|
||||
NewAuthzLink
|
||||
// AuthzLink new-authz
|
||||
AuthzLink
|
||||
// ChallengeLink challenge
|
||||
ChallengeLink
|
||||
// CertificateLink certificate
|
||||
CertificateLink
|
||||
// DirectoryLink directory
|
||||
DirectoryLink
|
||||
// RevokeCertLink revoke certificate
|
||||
RevokeCertLink
|
||||
// KeyChangeLink key rollover
|
||||
KeyChangeLink
|
||||
)
|
||||
|
||||
func (l Link) String() string {
|
||||
switch l {
|
||||
case NewNonceLink:
|
||||
return "new-nonce"
|
||||
case NewAccountLink:
|
||||
return "new-account"
|
||||
case AccountLink:
|
||||
return "account"
|
||||
case NewOrderLink:
|
||||
return "new-order"
|
||||
case OrderLink:
|
||||
return "order"
|
||||
case NewAuthzLink:
|
||||
return "new-authz"
|
||||
case AuthzLink:
|
||||
return "authz"
|
||||
case ChallengeLink:
|
||||
return "challenge"
|
||||
case CertificateLink:
|
||||
return "certificate"
|
||||
case DirectoryLink:
|
||||
return "directory"
|
||||
case RevokeCertLink:
|
||||
return "revoke-cert"
|
||||
case KeyChangeLink:
|
||||
return "key-change"
|
||||
default:
|
||||
return "unexpected"
|
||||
}
|
||||
}
|
||||
|
||||
func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
|
||||
var provName string
|
||||
if p, err := ProvisionerFromContext(ctx); err == nil && p != nil {
|
||||
provName = p.GetName()
|
||||
}
|
||||
return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...)
|
||||
}
|
||||
|
||||
// getLinkExplicit returns an absolute or partial path to the given resource and a base
|
||||
// URL dynamically obtained from the request for which the link is being
|
||||
// calculated.
|
||||
func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
|
||||
var link string
|
||||
switch typ {
|
||||
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
|
||||
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
|
||||
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
|
||||
case OrdersByAccountLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
|
||||
case FinalizeLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
|
||||
}
|
||||
|
||||
if abs {
|
||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||
u := url.URL{}
|
||||
if baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
|
||||
// If no Scheme is set, then default to https.
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||
if u.Host == "" {
|
||||
u.Host = d.dns
|
||||
}
|
||||
|
||||
u.Path = d.prefix + link
|
||||
return u.String()
|
||||
}
|
||||
return link
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestDirectoryGetLink(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
dir := newDirectory(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
|
||||
|
||||
assert.Equals(t, dir.getLink(ctx, NewNonceLink, true),
|
||||
fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName))
|
||||
assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
|
||||
|
||||
// No provisioner
|
||||
ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL)
|
||||
assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true),
|
||||
fmt.Sprintf("%s/acme//new-nonce", baseURL.String()))
|
||||
assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce")
|
||||
|
||||
// No baseURL
|
||||
ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
||||
assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true),
|
||||
fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName))
|
||||
assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
|
||||
|
||||
assert.Equals(t, dir.getLink(ctx, OrderLink, true, id),
|
||||
fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName))
|
||||
assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName))
|
||||
}
|
||||
|
||||
func TestDirectoryGetLinkExplicit(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prefix := "acme"
|
||||
dir := newDirectory(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provID := url.PathEscape(prov.GetName())
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID))
|
||||
assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID))
|
||||
}
|
@ -1,73 +1,9 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
// Nonce represents an ACME nonce type.
|
||||
type Nonce string
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// nonce contains nonce metadata used in the ACME protocol.
|
||||
type nonce struct {
|
||||
ID string
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// newNonce creates, stores, and returns an ACME replay-nonce.
|
||||
func newNonce(db nosql.DB) (*nonce, error) {
|
||||
_id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||
n := &nonce{
|
||||
ID: id,
|
||||
Created: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing nonce; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
// useNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func useNonce(db nosql.DB, nonce string) error {
|
||||
err := db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Get,
|
||||
},
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Delete,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
return BadNonceErr(nil)
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
// String implements the ToString interface.
|
||||
func (n Nonce) String() string {
|
||||
return string(n)
|
||||
}
|
||||
|
@ -1,163 +0,0 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestNewNonce(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if n, err := newNonce(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, n.ID, *tc.id)
|
||||
|
||||
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseNonce(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/update-not-found": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return database.ErrNotFound
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: BadNonceErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/update-error": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := useNonce(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,20 @@
|
||||
package acme
|
||||
|
||||
// Status represents an ACME status.
|
||||
type Status string
|
||||
|
||||
var (
|
||||
// StatusValid -- valid
|
||||
StatusValid = Status("valid")
|
||||
// StatusInvalid -- invalid
|
||||
StatusInvalid = Status("invalid")
|
||||
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||
StatusPending = Status("pending")
|
||||
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||
StatusDeactivated = Status("deactivated")
|
||||
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||
StatusReady = Status("ready")
|
||||
//statusExpired = "expired"
|
||||
//statusActive = "active"
|
||||
//statusProcessing = "processing"
|
||||
)
|
Loading…
Reference in New Issue