Add ACME CA capabilities
parent
68ab03dc1b
commit
e3826dd1c3
@ -0,0 +1,214 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Account is a subset of the internal account type containing only those
|
||||
// attributes required for responses in the ACME protocol.
|
||||
type Account struct {
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Orders string `json:"orders"`
|
||||
ID string `json:"-"`
|
||||
Key *jose.JSONWebKey `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Account) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the account ID.
|
||||
func (a *Account) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// GetKey returns the JWK associated with the account.
|
||||
func (a *Account) GetKey() *jose.JSONWebKey {
|
||||
return a.Key
|
||||
}
|
||||
|
||||
// IsValid returns true if the Account is valid.
|
||||
func (a *Account) IsValid() bool {
|
||||
return a.Status == StatusValid
|
||||
}
|
||||
|
||||
// AccountOptions are the options needed to create a new ACME account.
|
||||
type AccountOptions struct {
|
||||
Key *jose.JSONWebKey
|
||||
Contact []string
|
||||
}
|
||||
|
||||
// account represents an ACME account.
|
||||
type account struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
Deactivated time.Time `json:"deactivated"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// newAccount returns a new acme account type.
|
||||
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &account{
|
||||
ID: id,
|
||||
Key: ops.Key,
|
||||
Contact: ops.Contact,
|
||||
Status: "valid",
|
||||
Created: clock.Now(),
|
||||
}
|
||||
return a, a.saveNew(db)
|
||||
}
|
||||
|
||||
// toACME converts the internal Account type into the public acmeAccount
|
||||
// type for presentation in the ACME protocol.
|
||||
func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) {
|
||||
return &Account{
|
||||
Status: a.Status,
|
||||
Contact: a.Contact,
|
||||
Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID),
|
||||
Key: a.Key,
|
||||
ID: a.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// save writes the Account to the DB.
|
||||
// If the account is new then the necessary indices will be created.
|
||||
// Else, the account in the DB will be updated.
|
||||
func (a *account) saveNew(db nosql.DB) error {
|
||||
kid, err := keyToID(a.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kidB := []byte(kid)
|
||||
|
||||
// Set the jwkID -> acme account ID index
|
||||
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
|
||||
default:
|
||||
if err = a.save(db, nil); err != nil {
|
||||
db.Del(accountByKeyIDTable, kidB)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *account) save(db nosql.DB, old *account) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(*a)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling new account object")
|
||||
}
|
||||
// Set the Account
|
||||
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing account"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing account; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// update updates the acme account object stored in the database if,
|
||||
// and only if, the account has not changed since the last read.
|
||||
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
|
||||
b := *a
|
||||
b.Contact = contact
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// deactivate deactivates the acme account.
|
||||
func (a *account) deactivate(db nosql.DB) (*account, error) {
|
||||
b := *a
|
||||
b.Status = StatusDeactivated
|
||||
b.Deactivated = clock.Now()
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// getAccountByID retrieves the account with the given ID.
|
||||
func getAccountByID(db nosql.DB, id string) (*account, error) {
|
||||
ab, err := db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
|
||||
}
|
||||
|
||||
a := new(account)
|
||||
if err = json.Unmarshal(ab, a); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// getAccountByKeyID retrieves Id associated with the given Kid.
|
||||
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
|
||||
id, err := db.Get(accountByKeyIDTable, []byte(kid))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
|
||||
}
|
||||
return getAccountByID(db, string(id))
|
||||
}
|
||||
|
||||
// getOrderIDsByAccount retrieves a list of Order IDs that were created by the
|
||||
// account.
|
||||
func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) {
|
||||
b, err := db.Get(ordersByAccountIDTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return []string{}, nil
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id))
|
||||
}
|
||||
var orderIDs []string
|
||||
if err := json.Unmarshal(b, &orderIDs); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id))
|
||||
}
|
||||
return orderIDs, nil
|
||||
}
|
@ -0,0 +1,844 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func newProv() provisioner.Interface {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func newAcc() (*account, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newAccount(mockdb, AccountOptions{
|
||||
Key: jwk, Contact: []string{"foo", "bar"},
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAccountByID(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountByKeyID(t *testing.T) {
|
||||
type test struct {
|
||||
kid string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/kid-not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading key-account index: force")),
|
||||
}
|
||||
},
|
||||
"fail/getAccount-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
count++
|
||||
return []byte("bar"), nil
|
||||
}
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading account bar: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
kid: acc.Key.KeyID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
var ret []byte
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(acc.Key.KeyID))
|
||||
ret = []byte(acc.ID)
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
ret = b
|
||||
}
|
||||
count++
|
||||
return ret, nil
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountIDsByAccount(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
res []string
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
res: []string{},
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
oids := []string{"foo", "bar", "baz"}
|
||||
b, err := json.Marshal(oids)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
res: oids,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.res, oids)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{acc: acc}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAccount, err := tc.acc.toACME(nil, dir, prov)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
|
||||
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
|
||||
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acmeAccount.Orders,
|
||||
fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSave(t *testing.T) {
|
||||
type test struct {
|
||||
acc, old *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAcc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAcc)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: oldAcc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSaveNew(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/keyToID-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc.Key.Key = "foo"
|
||||
return test{
|
||||
acc: acc,
|
||||
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
|
||||
}
|
||||
},
|
||||
"fail/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"fail/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
MDel: func(bucket, key []byte) error {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
return nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.saveNew(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUpdate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
contact []string
|
||||
db nosql.DB
|
||||
res []byte
|
||||
err *Error
|
||||
}
|
||||
contact := []string{"foo", "bar"}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
res: b,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.update(tc.db, tc.contact)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, b, tc.res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeactivate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.deactivate(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.acc.ID)
|
||||
assert.Equals(t, acc.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acc.Status, StatusDeactivated)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acc.Created, tc.acc.Created)
|
||||
|
||||
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAccount(t *testing.T) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(jwk)
|
||||
assert.FatalError(t, err)
|
||||
ops := AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
type test struct {
|
||||
ops AccountOptions
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/store-error": func(t *testing.T) test {
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
count := 0
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
*id = string(key)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acc, err := newAccount(tc.db, tc.ops)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, *tc.id)
|
||||
assert.Equals(t, acc.Status, StatusValid)
|
||||
assert.Equals(t, acc.Contact, ops.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
|
||||
|
||||
assert.True(t, acc.Deactivated.IsZero())
|
||||
|
||||
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,213 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// NewAccountRequest represents the payload for a new account request.
|
||||
type NewAccountRequest struct {
|
||||
Contact []string `json:"contact"`
|
||||
OnlyReturnExisting bool `json:"onlyReturnExisting"`
|
||||
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
|
||||
}
|
||||
|
||||
func validateContacts(cs []string) error {
|
||||
for _, c := range cs {
|
||||
if len(c) == 0 {
|
||||
return acme.MalformedErr(errors.New("contact cannot be empty string"))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a new-account request body.
|
||||
func (n *NewAccountRequest) Validate() error {
|
||||
if n.OnlyReturnExisting && len(n.Contact) > 0 {
|
||||
return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
|
||||
}
|
||||
return validateContacts(n.Contact)
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents an update-account request.
|
||||
type UpdateAccountRequest struct {
|
||||
Contact []string `json:"contact"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// IsDeactivateRequest returns true if the update request is a deactivation
|
||||
// request, false otherwise.
|
||||
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
|
||||
return u.Status == acme.StatusDeactivated
|
||||
}
|
||||
|
||||
// Validate validates a update-account request body.
|
||||
func (u *UpdateAccountRequest) Validate() error {
|
||||
switch {
|
||||
case len(u.Status) > 0 && len(u.Contact) > 0:
|
||||
return acme.MalformedErr(errors.New("incompatible input; contact and " +
|
||||
"status updates are mutually exclusive"))
|
||||
case len(u.Contact) > 0:
|
||||
if err := validateContacts(u.Contact); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case len(u.Status) > 0:
|
||||
if u.Status != acme.StatusDeactivated {
|
||||
return acme.MalformedErr(errors.Errorf("cannot update account "+
|
||||
"status to %s, only deactivated", u.Status))
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return acme.MalformedErr(errors.Errorf("empty update request"))
|
||||
}
|
||||
}
|
||||
|
||||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var nar NewAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||
"failed to unmarshal new-account request payload")))
|
||||
return
|
||||
}
|
||||
if err := nar.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpStatus := http.StatusCreated
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
acmeErr, ok := err.(*acme.Error)
|
||||
if !ok || acmeErr.Status != http.StatusNotFound {
|
||||
// Something went wrong ...
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Account does not exist //
|
||||
if nar.OnlyReturnExisting {
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
}); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Account exists //
|
||||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink,
|
||||
acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||
api.JSONStatus(w, acc, httpStatus)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUpdateAccount is the api for updating an ACME account.
|
||||
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !payload.isPostAsGet {
|
||||
var uar UpdateAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
|
||||
return
|
||||
}
|
||||
if err := uar.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if uar.IsDeactivateRequest() {
|
||||
acc, err = h.Auth.DeactivateAccount(prov, acc.GetID())
|
||||
} else {
|
||||
acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact)
|
||||
}
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||
api.JSON(w, acc)
|
||||
return
|
||||
}
|
||||
|
||||
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
"orders": oids,
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
|
||||
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
accID := chi.URLParam(r, "accID")
|
||||
if acc.ID != accID {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
|
||||
return
|
||||
}
|
||||
orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
api.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
return
|
||||
}
|
@ -0,0 +1,790 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func newProv() provisioner.Interface {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestNewAccountRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
nar *NewAccountRequest
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/incompatible-input": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
Contact: []string{"foo", "bar"},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")),
|
||||
}
|
||||
},
|
||||
"fail/bad-contact": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/onlyReturnExisting": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.nar.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
uar *UpdateAccountRequest
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/incompatible-input": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
Status: "foo",
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("incompatible input; " +
|
||||
"contact and status updates are mutually exclusive")),
|
||||
}
|
||||
},
|
||||
"fail/bad-contact": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"fail/bad-status": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Status: "foo",
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("cannot update account " +
|
||||
"status to foo, only deactivated")),
|
||||
}
|
||||
},
|
||||
"ok/contact": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/status": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Status: "deactivated",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.uar.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetOrdersByAccount(t *testing.T) {
|
||||
oids := []string{
|
||||
"https://ca.smallstep.com/acme/order/foo",
|
||||
"https://ca.smallstep.com/acme/order/bar",
|
||||
}
|
||||
accID := "account-id"
|
||||
prov := newProv()
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("accID", accID)
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "foo"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 401,
|
||||
problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")),
|
||||
}
|
||||
},
|
||||
"fail/getOrdersByAccount-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, acc.ID)
|
||||
return oids, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrdersByAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(oids)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerNewAccount(t *testing.T) {
|
||||
accID := "accountID"
|
||||
acc := acme.Account{
|
||||
ID: accID,
|
||||
Status: "valid",
|
||||
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||
}
|
||||
prov := newProv()
|
||||
|
||||
url := "https://ca.smallstep.com/acme/new-account"
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"fail/no-existing-account": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-jwk": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-jwk": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/NewAccount-error": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.Contact, nar.Contact)
|
||||
assert.Equals(t, ops.Key, jwk)
|
||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok/new-account": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.Contact, nar.Contact)
|
||||
assert.Equals(t, ops.Key, jwk)
|
||||
return &acc, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
"ok/return-existing": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||
accID := "accountID"
|
||||
acc := acme.Account{
|
||||
ID: accID,
|
||||
Status: "valid",
|
||||
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||
}
|
||||
prov := newProv()
|
||||
|
||||
// Request with chi context
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"fail/Deactivate-error": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Status: "deactivated",
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"fail/UpdateAccount-error": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
assert.Equals(t, contacts, uar.Contact)
|
||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok/deactivate": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Status: "deactivated",
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
return &acc, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/new-account": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
assert.Equals(t, contacts, uar.Contact)
|
||||
return &acc, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/post-as-get": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetUpdateAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,214 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func link(url, typ string) string {
|
||||
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
accContextKey = contextKey("acc")
|
||||
jwsContextKey = contextKey("jws")
|
||||
jwkContextKey = contextKey("jwk")
|
||||
payloadContextKey = contextKey("payload")
|
||||
provisionerContextKey = contextKey("provisioner")
|
||||
)
|
||||
|
||||
type payloadInfo struct {
|
||||
value []byte
|
||||
isPostAsGet bool
|
||||
isEmptyJSON bool
|
||||
}
|
||||
|
||||
func accountFromContext(r *http.Request) (*acme.Account, error) {
|
||||
val, ok := r.Context().Value(accContextKey).(*acme.Account)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.AccountDoesNotExistErr(nil)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) {
|
||||
val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) {
|
||||
val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func payloadFromContext(r *http.Request) (*payloadInfo, error) {
|
||||
val, ok := r.Context().Value(payloadContextKey).(*payloadInfo)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func provisionerFromContext(r *http.Request) (provisioner.Interface, error) {
|
||||
val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// New returns a new ACME API router.
|
||||
func New(acmeAuth acme.Interface) api.RouterHandler {
|
||||
return &Handler{acmeAuth}
|
||||
}
|
||||
|
||||
// Handler is the ACME request handler.
|
||||
type Handler struct {
|
||||
Auth acme.Interface
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
getLink := h.Auth.GetLink
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||
}
|
||||
|
||||
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount))
|
||||
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
|
||||
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder))
|
||||
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
|
||||
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
|
||||
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||
}
|
||||
|
||||
// GetNonce just sets the right header since a Nonce is added to each response
|
||||
// by middleware by default.
|
||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetDirectory is the ACME resource for returning a directory configuration
|
||||
// for client configuration.
|
||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
dir := h.Auth.GetDirectory(prov)
|
||||
api.JSON(w, dir)
|
||||
return
|
||||
}
|
||||
|
||||
// GetAuthz ACME api for retrieving an Authz.
|
||||
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID()))
|
||||
api.JSON(w, authz)
|
||||
return
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
// Just verify that the payload was set, since we're not strictly adhering
|
||||
// to ACME V2 spec for reasons specified below.
|
||||
_, err = payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// NOTE: We should be checking that the request is either a POST-as-GET, or
|
||||
// that the payload is an empty JSON block ({}). However, older ACME clients
|
||||
// still send a vestigial body (rather than an empty JSON block) and
|
||||
// strict enforcement would render these clients broken. For the time being
|
||||
// we'll just ignore the body.
|
||||
var (
|
||||
ch *acme.Challenge
|
||||
chID = chi.URLParam(r, "chID")
|
||||
)
|
||||
ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
getLink := h.Auth.GetLink
|
||||
w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up"))
|
||||
w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID()))
|
||||
api.JSON(w, ch)
|
||||
return
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
certID := chi.URLParam(r, "certID")
|
||||
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
||||
w.Write(certBytes)
|
||||
return
|
||||
}
|
@ -0,0 +1,771 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
type mockAcmeAuthority struct {
|
||||
deactivateAccount func(provisioner.Interface, string) (*acme.Account, error)
|
||||
finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error)
|
||||
getAccount func(p provisioner.Interface, id string) (*acme.Account, error)
|
||||
getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error)
|
||||
getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error)
|
||||
getCertificate func(accID string, id string) ([]byte, error)
|
||||
getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error)
|
||||
getDirectory func(provisioner.Interface) *acme.Directory
|
||||
getLink func(acme.Link, string, bool, ...string) string
|
||||
getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error)
|
||||
getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error)
|
||||
loadProvisionerByID func(string) (provisioner.Interface, error)
|
||||
newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error)
|
||||
newNonce func() (string, error)
|
||||
newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error)
|
||||
updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error)
|
||||
useNonce func(string) error
|
||||
validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error)
|
||||
ret1 interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
if m.deactivateAccount != nil {
|
||||
return m.deactivateAccount(p, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
|
||||
if m.finalizeOrder != nil {
|
||||
return m.finalizeOrder(p, accID, id, csr)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Order), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
if m.getAccount != nil {
|
||||
return m.getAccount(p, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) {
|
||||
if m.getAccountByKey != nil {
|
||||
return m.getAccountByKey(p, jwk)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
|
||||
if m.getAuthz != nil {
|
||||
return m.getAuthz(p, accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Authz), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) {
|
||||
if m.getCertificate != nil {
|
||||
return m.getCertificate(accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.([]byte), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) {
|
||||
if m.getChallenge != nil {
|
||||
return m.getChallenge(p, accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Challenge), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory {
|
||||
if m.getDirectory != nil {
|
||||
return m.getDirectory(p)
|
||||
}
|
||||
return m.ret1.(*acme.Directory)
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
if m.getLink != nil {
|
||||
return m.getLink(typ, provID, abs, in...)
|
||||
}
|
||||
return m.ret1.(string)
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) {
|
||||
if m.getOrder != nil {
|
||||
return m.getOrder(p, accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Order), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||
if m.getOrdersByAccount != nil {
|
||||
return m.getOrdersByAccount(p, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.([]string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByID != nil {
|
||||
return m.loadProvisionerByID(provID)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||
if m.newAccount != nil {
|
||||
return m.newAccount(p, ops)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) NewNonce() (string, error) {
|
||||
if m.newNonce != nil {
|
||||
return m.newNonce()
|
||||
} else if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.ret1.(string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
if m.newOrder != nil {
|
||||
return m.newOrder(p, ops)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Order), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) {
|
||||
if m.updateAccount != nil {
|
||||
return m.updateAccount(p, id, contact)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) UseNonce(nonce string) error {
|
||||
if m.useNonce != nil {
|
||||
return m.useNonce(nonce)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
||||
switch {
|
||||
case m.validateChallenge != nil:
|
||||
return m.validateChallenge(p, accID, id, jwk)
|
||||
case m.err != nil:
|
||||
return nil, m.err
|
||||
default:
|
||||
return m.ret1.(*acme.Challenge), m.err
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
}{
|
||||
{"GET", 204},
|
||||
{"HEAD", 200},
|
||||
}
|
||||
|
||||
// Request with chi context
|
||||
req := httptest.NewRequest("GET", "http://ca.smallstep.com/nonce", nil)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(nil).(*Handler)
|
||||
w := httptest.NewRecorder()
|
||||
req.Method = tt.name
|
||||
h.GetNonce(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("Handler.GetNonce StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetDirectory(t *testing.T) {
|
||||
auth := acme.NewAuthority(nil, "ca.smallstep.com", "acme", nil)
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov))
|
||||
|
||||
expDir := acme.Directory{
|
||||
NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)),
|
||||
NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)),
|
||||
NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)),
|
||||
RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)),
|
||||
KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)),
|
||||
}
|
||||
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetDirectory(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
var dir acme.Directory
|
||||
json.Unmarshal(bytes.TrimSpace(body), &dir)
|
||||
assert.Equals(t, dir, expDir)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetAuthz(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
az := acme.Authz{
|
||||
ID: "authzID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "example.com",
|
||||
},
|
||||
Status: "pending",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
Wildcard: false,
|
||||
Challenges: []*acme.Challenge{
|
||||
{
|
||||
Type: "http-01",
|
||||
Status: "pending",
|
||||
Token: "tok2",
|
||||
URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
|
||||
ID: "chHTTP01ID",
|
||||
AuthzID: "authzID",
|
||||
},
|
||||
{
|
||||
Type: "dns-01",
|
||||
Status: "pending",
|
||||
Token: "tok2",
|
||||
URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
|
||||
ID: "chDNSID",
|
||||
AuthzID: "authzID",
|
||||
},
|
||||
},
|
||||
}
|
||||
prov := newProv()
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("authzID", az.ID)
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s",
|
||||
acme.URLSafeProvisionerName(prov), az.ID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/getAuthz-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, az.ID)
|
||||
return &az, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.Equals(t, typ, acme.AuthzLink)
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{az.ID})
|
||||
return url
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAuthz(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
//var gotAz acme.Authz
|
||||
//assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &gotAz))
|
||||
expB, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetCertificate(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
certBytes := append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}), pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: inter.Raw,
|
||||
})...)
|
||||
certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})...)
|
||||
certID := "certID"
|
||||
|
||||
prov := newProv()
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("certID", certID)
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s",
|
||||
acme.URLSafeProvisionerName(prov), certID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/getCertificate-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getCertificate: func(accID, id string) ([]byte, error) {
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, certID)
|
||||
return certBytes, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetCertificate(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ch() acme.Challenge {
|
||||
return acme.Challenge{
|
||||
Type: "http-01",
|
||||
Status: "pending",
|
||||
Token: "tok2",
|
||||
URL: "https://ca.smallstep.com/acme/challenge/chID",
|
||||
ID: "chID",
|
||||
AuthzID: "authzID",
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetChallenge(t *testing.T) {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("chID", "chID")
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID")
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
ch acme.Challenge
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.UnauthorizedErr(nil),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 401,
|
||||
problem: acme.UnauthorizedErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/get-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.UnauthorizedErr(nil),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 401,
|
||||
problem: acme.UnauthorizedErr(nil),
|
||||
}
|
||||
},
|
||||
"ok/validate-challenge": func(t *testing.T) test {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accID", Key: key}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ch := ch()
|
||||
ch.Status = "valid"
|
||||
ch.Validated = time.Now().UTC().Format(time.RFC3339)
|
||||
count := 0
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, ch.ID)
|
||||
assert.Equals(t, jwk.KeyID, key.KeyID)
|
||||
return &ch, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
var ret string
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, typ, acme.AuthzLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{ch.AuthzID})
|
||||
ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID)
|
||||
case 1:
|
||||
assert.Equals(t, typ, acme.ChallengeLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{ch.ID})
|
||||
ret = url
|
||||
}
|
||||
count++
|
||||
return ret
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
ch: ch,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetChallenge(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(tc.ch)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<https://ca.smallstep.com/acme/authz/%s>;rel=\"up\"", tc.ch.AuthzID)})
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,377 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
func logNonce(w http.ResponseWriter, nonce string) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
"nonce": nonce,
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce, err := h.Auth.NewNonce()
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Replay-Nonce", nonce)
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
logNonce(w, nonce)
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||
// directory index url.
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index"))
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// verifyContentType is a middleware that verifies that content type is
|
||||
// application/jose+json.
|
||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
ct := r.Header.Get("Content-Type")
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
} else {
|
||||
// By default every request should have content-type applictaion/jose+json.
|
||||
expected = []string{"application/jose+json"}
|
||||
}
|
||||
for _, e := range expected {
|
||||
if ct == e {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf(
|
||||
"expected content-type to be in %s, but got %s", expected, ct)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
|
||||
return
|
||||
}
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// validateJWS checks the request body for to verify that it meets ACME
|
||||
// requirements for a JWS.
|
||||
//
|
||||
// The JWS MUST NOT have multiple signatures
|
||||
// The JWS Unencoded Payload Option [RFC7797] MUST NOT be used
|
||||
// The JWS Unprotected Header [RFC7515] MUST NOT be used
|
||||
// The JWS Payload MUST NOT be detached
|
||||
// The JWS Protected Header MUST include the following fields:
|
||||
// * “alg” (Algorithm)
|
||||
// * This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// * “nonce” (defined in Section 6.5)
|
||||
// * “url” (defined in Section 6.4)
|
||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
|
||||
return
|
||||
}
|
||||
|
||||
sig := jws.Signatures[0]
|
||||
uh := sig.Unprotected
|
||||
if len(uh.KeyID) > 0 ||
|
||||
uh.JSONWebKey != nil ||
|
||||
len(uh.Algorithm) > 0 ||
|
||||
len(uh.Nonce) > 0 ||
|
||||
len(uh.ExtraHeaders) > 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
|
||||
return
|
||||
}
|
||||
hdr := sig.Protected
|
||||
switch hdr.Algorithm {
|
||||
case jose.RS256, jose.RS384, jose.RS512:
|
||||
if hdr.JSONWebKey != nil {
|
||||
switch k := hdr.JSONWebKey.Key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if k.Size() < keys.MinRSAKeyBytes {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
|
||||
"keys must be at least %d bits (%d bytes) in size",
|
||||
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)))
|
||||
return
|
||||
}
|
||||
default:
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
|
||||
return
|
||||
}
|
||||
}
|
||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||
// we good
|
||||
default:
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
|
||||
return
|
||||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the JWS url matches the requested url.
|
||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||
if !ok {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
|
||||
return
|
||||
}
|
||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||
if jwsURL != reqURL.String() {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
|
||||
return
|
||||
}
|
||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||
// in the context. Make sure to parse and validate the JWS before running this
|
||||
// middleware.
|
||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||
if jwk == nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
|
||||
return
|
||||
}
|
||||
if !jwk.Valid() {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
acc, err := h.Auth.GetAccountByKey(prov, jwk)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
// For NewAccount requests ...
|
||||
break
|
||||
case err != nil:
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responsds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
name := chi.URLParam(r, "provisionerID")
|
||||
provID, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
|
||||
return
|
||||
}
|
||||
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if p.GetType() != provisioner.TypeACME {
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, p)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||
// kid parameter of the signed payload.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
|
||||
"required prefix; expected %s, but got %s", kidPrefix, kid)))
|
||||
return
|
||||
}
|
||||
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.Auth.GetAccount(prov, accID)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
return
|
||||
case err != nil:
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
|
||||
return
|
||||
}
|
||||
payload, err := jws.Verify(jwk)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{
|
||||
value: payload,
|
||||
isPostAsGet: string(payload) == "",
|
||||
isEmptyJSON: string(payload) == "{}",
|
||||
})
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if !payload.isPostAsGet {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,164 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
)
|
||||
|
||||
// NewOrderRequest represents the body for a NewOrder request.
|
||||
type NewOrderRequest struct {
|
||||
Identifiers []acme.Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
}
|
||||
|
||||
// Validate validates a new-order request body.
|
||||
func (n *NewOrderRequest) Validate() error {
|
||||
if len(n.Identifiers) == 0 {
|
||||
return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
|
||||
}
|
||||
for _, id := range n.Identifiers {
|
||||
if id.Type != "dns" {
|
||||
return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinalizeRequest captures the body for a Finalize order request.
|
||||
type FinalizeRequest struct {
|
||||
CSR string `json:"csr"`
|
||||
csr *x509.CertificateRequest
|
||||
}
|
||||
|
||||
// Validate validates a finalize request body.
|
||||
func (f *FinalizeRequest) Validate() error {
|
||||
var err error
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||
if err != nil {
|
||||
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
|
||||
}
|
||||
f.csr, err = x509.ParseCertificateRequest(csrBytes)
|
||||
if err != nil {
|
||||
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
|
||||
}
|
||||
if err = f.csr.CheckSignature(); err != nil {
|
||||
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewOrder ACME api for creating a new order.
|
||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var nor NewOrderRequest
|
||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||
"failed to unmarshal new-order request payload")))
|
||||
return
|
||||
}
|
||||
if err := nor.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := h.Auth.NewOrder(prov, acme.OrderOptions{
|
||||
AccountID: acc.GetID(),
|
||||
Identifiers: nor.Identifiers,
|
||||
NotBefore: nor.NotBefore,
|
||||
NotAfter: nor.NotAfter,
|
||||
})
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||
api.JSONStatus(w, o, http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// GetOrder ACME api for retrieving an order.
|
||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
oid := chi.URLParam(r, "ordID")
|
||||
o, err := h.Auth.GetOrder(prov, acc.GetID(), oid)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||
api.JSON(w, o)
|
||||
return
|
||||
}
|
||||
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var fr FinalizeRequest
|
||||
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
|
||||
return
|
||||
}
|
||||
if err := fr.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
oid := chi.URLParam(r, "ordID")
|
||||
o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID))
|
||||
api.JSON(w, o)
|
||||
return
|
||||
}
|
@ -0,0 +1,757 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestNewOrderRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
nor *NewOrderRequest
|
||||
nbf, naf time.Time
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-identifiers": func(t *testing.T) test {
|
||||
return test{
|
||||
nor: &NewOrderRequest{},
|
||||
err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")),
|
||||
}
|
||||
},
|
||||
"fail/bad-identifier": func(t *testing.T) test {
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "foo", Value: "bar.com"},
|
||||
},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.nor.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if tc.nbf.IsZero() {
|
||||
assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute)))
|
||||
} else {
|
||||
assert.Equals(t, tc.nor.NotBefore, tc.nbf)
|
||||
}
|
||||
if tc.naf.IsZero() {
|
||||
assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour)))
|
||||
assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute)))
|
||||
} else {
|
||||
assert.Equals(t, tc.nor.NotAfter, tc.naf)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestValidate(t *testing.T) {
|
||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||
assert.FatalError(t, err)
|
||||
csr, ok := _csr.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
type test struct {
|
||||
fr *FinalizeRequest
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/parse-csr-error": func(t *testing.T) test {
|
||||
return test{
|
||||
fr: &FinalizeRequest{},
|
||||
err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")),
|
||||
}
|
||||
},
|
||||
"fail/invalid-csr-signature": func(t *testing.T) test {
|
||||
b, err := pemutil.Read("../../authority/testdata/certs/badsig.csr")
|
||||
assert.FatalError(t, err)
|
||||
c, ok := b.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
return test{
|
||||
fr: &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(c.Raw),
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
fr: &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.fr.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.fr.csr.Raw, csr.Raw)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetOrder(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
nbf := time.Now().UTC()
|
||||
naf := time.Now().UTC().Add(24 * time.Hour)
|
||||
o := acme.Order{
|
||||
ID: "orderID",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
NotBefore: nbf.Format(time.RFC3339),
|
||||
NotAfter: naf.Format(time.RFC3339),
|
||||
Identifiers: []acme.Identifier{
|
||||
{
|
||||
Type: "dns",
|
||||
Value: "example.com",
|
||||
},
|
||||
{
|
||||
Type: "dns",
|
||||
Value: "*.smallstep.com",
|
||||
},
|
||||
},
|
||||
Status: "pending",
|
||||
Authorizations: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("ordID", o.ID)
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/getOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, o.ID)
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return url
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(o)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerNewOrder(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
nbf := time.Now().UTC().Add(5 * time.Hour)
|
||||
naf := nbf.Add(17 * time.Hour)
|
||||
o := acme.Order{
|
||||
ID: "orderID",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
NotBefore: nbf.Format(time.RFC3339),
|
||||
NotAfter: naf.Format(time.RFC3339),
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
Status: "pending",
|
||||
Authorizations: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order",
|
||||
acme.URLSafeProvisionerName(prov))
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")),
|
||||
}
|
||||
},
|
||||
"fail/NewOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.AccountID, acc.ID)
|
||||
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||
return nil, acme.MalformedErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
NotBefore: nbf,
|
||||
NotAfter: naf,
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.AccountID, acc.ID)
|
||||
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||
assert.Equals(t, ops.NotBefore, nbf)
|
||||
assert.Equals(t, ops.NotAfter, naf)
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
"ok/default-naf-nbf": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.AccountID, acc.ID)
|
||||
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||
|
||||
assert.True(t, ops.NotBefore.IsZero())
|
||||
assert.True(t, ops.NotAfter.IsZero())
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(o)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerFinalizeOrder(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
nbf := time.Now().UTC().Add(5 * time.Hour)
|
||||
naf := nbf.Add(17 * time.Hour)
|
||||
o := acme.Order{
|
||||
ID: "orderID",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
NotBefore: nbf.Format(time.RFC3339),
|
||||
NotAfter: naf.Format(time.RFC3339),
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
Status: "valid",
|
||||
Authorizations: []string{"foo", "bar"},
|
||||
Certificate: "https://ca.smallstep.com/acme/certificate/certID",
|
||||
}
|
||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||
assert.FatalError(t, err)
|
||||
csr, ok := _csr.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("ordID", o.ID)
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &FinalizeRequest{}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")),
|
||||
}
|
||||
},
|
||||
"fail/FinalizeOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, o.ID)
|
||||
assert.Equals(t, incsr.Raw, csr.Raw)
|
||||
return nil, acme.MalformedErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, o.ID)
|
||||
assert.Equals(t, incsr.Raw, csr.Raw)
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.FinalizeOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(o)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,263 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Interface is the acme authority interface.
|
||||
type Interface interface {
|
||||
DeactivateAccount(provisioner.Interface, string) (*Account, error)
|
||||
FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error)
|
||||
GetAccount(provisioner.Interface, string) (*Account, error)
|
||||
GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error)
|
||||
GetAuthz(provisioner.Interface, string, string) (*Authz, error)
|
||||
GetCertificate(string, string) ([]byte, error)
|
||||
GetDirectory(provisioner.Interface) *Directory
|
||||
GetLink(Link, string, bool, ...string) string
|
||||
GetOrder(provisioner.Interface, string, string) (*Order, error)
|
||||
GetOrdersByAccount(provisioner.Interface, string) ([]string, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
NewAccount(provisioner.Interface, AccountOptions) (*Account, error)
|
||||
NewNonce() (string, error)
|
||||
NewOrder(provisioner.Interface, OrderOptions) (*Order, error)
|
||||
UpdateAccount(provisioner.Interface, string, []string) (*Account, error)
|
||||
UseNonce(string) error
|
||||
ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error)
|
||||
}
|
||||
|
||||
// Authority is the layer that handles all ACME interactions.
|
||||
type Authority struct {
|
||||
db nosql.DB
|
||||
dir *directory
|
||||
signAuth SignAuthority
|
||||
}
|
||||
|
||||
// NewAuthority returns a new Authority that implements the ACME interface.
|
||||
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) *Authority {
|
||||
return &Authority{
|
||||
db: db, dir: newDirectory(dns, prefix), signAuth: signAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLink returns the requested link from the directory.
|
||||
func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string {
|
||||
return a.dir.getLink(typ, provID, abs, inputs...)
|
||||
}
|
||||
|
||||
// GetDirectory returns the ACME directory object.
|
||||
func (a *Authority) GetDirectory(p provisioner.Interface) *Directory {
|
||||
name := url.PathEscape(p.GetName())
|
||||
return &Directory{
|
||||
NewNonce: a.dir.getLink(NewNonceLink, name, true),
|
||||
NewAccount: a.dir.getLink(NewAccountLink, name, true),
|
||||
NewOrder: a.dir.getLink(NewOrderLink, name, true),
|
||||
RevokeCert: a.dir.getLink(RevokeCertLink, name, true),
|
||||
KeyChange: a.dir.getLink(KeyChangeLink, name, true),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadProvisionerByID calls out to the SignAuthority interface to load a
|
||||
// provisioner by ID.
|
||||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
return a.signAuth.LoadProvisionerByID(id)
|
||||
}
|
||||
|
||||
// NewNonce generates, stores, and returns a new ACME nonce.
|
||||
func (a *Authority) NewNonce() (string, error) {
|
||||
n, err := newNonce(a.db)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return n.ID, nil
|
||||
}
|
||||
|
||||
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
|
||||
func (a *Authority) UseNonce(nonce string) error {
|
||||
return useNonce(a.db, nonce)
|
||||
}
|
||||
|
||||
// NewAccount creates, stores, and returns a new ACME account.
|
||||
func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) {
|
||||
acc, err := newAccount(a.db, ao)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// UpdateAccount updates an ACME account.
|
||||
func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if acc, err = acc.update(a.db, contact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetAccount returns an ACME account.
|
||||
func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates an ACME account.
|
||||
func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if acc, err = acc.deactivate(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
func keyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
||||
// GetAccountByKey returns the ACME associated with the jwk id.
|
||||
func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) {
|
||||
kid, err := keyToID(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acc, err := getAccountByKeyID(a.db, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetOrder returns an ACME order.
|
||||
func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
if o, err = o.updateStatus(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetOrdersByAccount returns the list of order urls owned by the account.
|
||||
func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||
oids, err := getOrderIDsByAccount(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret = []string{}
|
||||
for _, oid := range oids {
|
||||
o, err := getOrder(a.db, oid)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if o.Status == StatusInvalid {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID))
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// NewOrder generates, stores, and returns a new ACME order.
|
||||
func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) {
|
||||
order, err := newOrder(a.db, ops)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating order")
|
||||
}
|
||||
return order.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and generate a new certificate.
|
||||
func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
o, err = o.finalize(a.db, csr, a.signAuth, p)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error finalizing order")
|
||||
}
|
||||
return o.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetAuthz retrieves and attempts to update the status on an ACME authz
|
||||
// before returning.
|
||||
func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) {
|
||||
authz, err := getAuthz(a.db, authzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != authz.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own authz"))
|
||||
}
|
||||
authz, err = authz.updateStatus(a.db)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error updating authz status")
|
||||
}
|
||||
return authz.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// ValidateChallenge attempts to validate the challenge.
|
||||
func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
|
||||
ch, err := getChallenge(a.db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != ch.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: time.Duration(30 * time.Second),
|
||||
}
|
||||
ch, err = ch.validate(a.db, jwk, validateOptions{
|
||||
httpGet: client.Get,
|
||||
lookupTxt: net.LookupTXT,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error attempting challenge validation")
|
||||
}
|
||||
return ch.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetCertificate retrieves the Certificate by ID.
|
||||
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
|
||||
cert, err := getCert(a.db, certID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != cert.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
|
||||
}
|
||||
return cert.toACME(a.db, a.dir)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,344 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
var defaultExpiryDuration = time.Hour * 24
|
||||
|
||||
// Authz is a subset of the Authz type containing only those attributes
|
||||
// required for responses in the ACME protocol.
|
||||
type Authz struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Authz) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Authz ID.
|
||||
func (a *Authz) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// authz is the interface that the various authz types must implement.
|
||||
type authz interface {
|
||||
save(nosql.DB, authz) error
|
||||
clone() *baseAuthz
|
||||
getID() string
|
||||
getAccountID() string
|
||||
getType() string
|
||||
getIdentifier() Identifier
|
||||
getStatus() string
|
||||
getExpiry() time.Time
|
||||
getWildcard() bool
|
||||
getChallenges() []string
|
||||
getCreated() time.Time
|
||||
updateStatus(db nosql.DB) (authz, error)
|
||||
toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error)
|
||||
}
|
||||
|
||||
// baseAuthz is the base authz type that others build from.
|
||||
type baseAuthz struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires time.Time `json:"expires"`
|
||||
Challenges []string `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
Created time.Time `json:"created"`
|
||||
Error *Error `json:"error"`
|
||||
}
|
||||
|
||||
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
ba := &baseAuthz{
|
||||
ID: id,
|
||||
AccountID: accID,
|
||||
Status: StatusPending,
|
||||
Created: now,
|
||||
Expires: now.Add(defaultExpiryDuration),
|
||||
Identifier: identifier,
|
||||
}
|
||||
|
||||
if strings.HasPrefix(identifier.Value, "*.") {
|
||||
ba.Wildcard = true
|
||||
ba.Identifier = Identifier{
|
||||
Value: strings.TrimPrefix(identifier.Value, "*."),
|
||||
Type: identifier.Type,
|
||||
}
|
||||
}
|
||||
|
||||
return ba, nil
|
||||
}
|
||||
|
||||
// getID returns the ID of the authz.
|
||||
func (ba *baseAuthz) getID() string {
|
||||
return ba.ID
|
||||
}
|
||||
|
||||
// getAccountID returns the Account ID that created the authz.
|
||||
func (ba *baseAuthz) getAccountID() string {
|
||||
return ba.AccountID
|
||||
}
|
||||
|
||||
// getType returns the type of the authz.
|
||||
func (ba *baseAuthz) getType() string {
|
||||
return ba.Identifier.Type
|
||||
}
|
||||
|
||||
// getIdentifier returns the identifier for the authz.
|
||||
func (ba *baseAuthz) getIdentifier() Identifier {
|
||||
return ba.Identifier
|
||||
}
|
||||
|
||||
// getStatus returns the status of the authz.
|
||||
func (ba *baseAuthz) getStatus() string {
|
||||
return ba.Status
|
||||
}
|
||||
|
||||
// getWildcard returns true if the authz identifier has a '*', false otherwise.
|
||||
func (ba *baseAuthz) getWildcard() bool {
|
||||
return ba.Wildcard
|
||||
}
|
||||
|
||||
// getChallenges returns the authz challenge IDs.
|
||||
func (ba *baseAuthz) getChallenges() []string {
|
||||
return ba.Challenges
|
||||
}
|
||||
|
||||
// getExpiry returns the expiration time of the authz.
|
||||
func (ba *baseAuthz) getExpiry() time.Time {
|
||||
return ba.Expires
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the authz.
|
||||
func (ba *baseAuthz) getCreated() time.Time {
|
||||
return ba.Created
|
||||
}
|
||||
|
||||
// toACME converts the internal Authz type into the public acmeAuthz type for
|
||||
// presentation in the ACME protocol.
|
||||
func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) {
|
||||
var chs = make([]*Challenge, len(ba.Challenges))
|
||||
for i, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chs[i], err = ch.toACME(db, dir, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Authz{
|
||||
Identifier: ba.Identifier,
|
||||
Status: ba.getStatus(),
|
||||
Challenges: chs,
|
||||
Wildcard: ba.getWildcard(),
|
||||
Expires: ba.Expires.Format(time.RFC3339),
|
||||
ID: ba.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
|
||||
var (
|
||||
err error
|
||||
oldB, newB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
|
||||
}
|
||||
}
|
||||
if newB, err = json.Marshal(ba); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing authz; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) clone() *baseAuthz {
|
||||
u := *ba
|
||||
return &u
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) storeAndReturnError(db nosql.DB, err *Error) error {
|
||||
clone := ba.clone()
|
||||
clone.Error = err
|
||||
clone.save(db, ba)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) parent() authz {
|
||||
return &dnsAuthz{ba}
|
||||
}
|
||||
|
||||
// updateStatus attempts to update the status on a baseAuthz and stores the
|
||||
// updating object if necessary.
|
||||
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
|
||||
newAuthz := ba.clone()
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch ba.Status {
|
||||
case StatusInvalid:
|
||||
return ba.parent(), nil
|
||||
case StatusValid:
|
||||
return ba.parent(), nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(ba.Expires) {
|
||||
newAuthz.Status = StatusInvalid
|
||||
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var isValid = false
|
||||
for _, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return ba, err
|
||||
}
|
||||
if ch.getStatus() == StatusValid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return ba.parent(), nil
|
||||
}
|
||||
newAuthz.Status = StatusValid
|
||||
newAuthz.Error = nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
|
||||
}
|
||||
|
||||
if err := newAuthz.save(db, ba); err != nil {
|
||||
return ba, err
|
||||
}
|
||||
return newAuthz.parent(), nil
|
||||
}
|
||||
|
||||
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
|
||||
func unmarshalAuthz(data []byte) (authz, error) {
|
||||
var getType struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &getType); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
|
||||
}
|
||||
|
||||
switch getType.Identifier.Type {
|
||||
case "dns":
|
||||
var ba baseAuthz
|
||||
if err := json.Unmarshal(data, &ba); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
|
||||
}
|
||||
return &dnsAuthz{&ba}, nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
|
||||
getType.Identifier.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// dnsAuthz represents a dns acme authorization.
|
||||
type dnsAuthz struct {
|
||||
*baseAuthz
|
||||
}
|
||||
|
||||
// newAuthz returns a new acme authorization object based on the identifier
|
||||
// type.
|
||||
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
|
||||
switch identifier.Type {
|
||||
case "dns":
|
||||
a, err = newDNSAuthz(db, accID, identifier)
|
||||
default:
|
||||
err = MalformedErr(errors.Errorf("unexpected authz type %s",
|
||||
identifier.Type))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// newDNSAuthz returns a new dns acme authorization object.
|
||||
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
|
||||
ba, err := newBaseAuthz(accID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ba.Challenges = []string{}
|
||||
if !ba.Wildcard {
|
||||
// http challenges are only permitted if the DNS is not a wildcard dns.
|
||||
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: ba.Identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating http challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch1.getID())
|
||||
}
|
||||
ch2, err := newDNS01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating dns challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch2.getID())
|
||||
|
||||
da := &dnsAuthz{ba}
|
||||
if err := da.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
// getAuthz retrieves and unmarshals an ACME authz type from the database.
|
||||
func getAuthz(db nosql.DB, id string) (authz, error) {
|
||||
b, err := db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
|
||||
}
|
||||
az, err := unmarshalAuthz(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return az, nil
|
||||
}
|
@ -0,0 +1,809 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func newAz() (authz, error) {
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
return newAuthz(mockdb, "1234", Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthz(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
az authz
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzClone(t *testing.T) {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
clone := az.clone()
|
||||
|
||||
assert.Equals(t, clone.getID(), az.getID())
|
||||
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, clone.getStatus(), az.getStatus())
|
||||
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, clone.getCreated(), az.getCreated())
|
||||
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
||||
|
||||
clone.Status = StatusValid
|
||||
|
||||
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
||||
}
|
||||
|
||||
func TestNewAuthz(t *testing.T) {
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
accID := "1234"
|
||||
type test struct {
|
||||
iden Identifier
|
||||
db nosql.DB
|
||||
err *Error
|
||||
resChs *([]string)
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
||||
err: MalformedErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"fail/new-http-chall-error": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/new-dns-chall-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/save-authz-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), false)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
"ok/wildcard": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
||||
return test{
|
||||
iden: _iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), true)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
// Verify that we only have 1 challenge instead of 2.
|
||||
assert.True(t, len(*chs) == 1)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
az, err := newAuthz(tc.db, accID, tc.iden)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getType(), "dns")
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
|
||||
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
||||
|
||||
if strings.HasPrefix(tc.iden.Value, "*.") {
|
||||
assert.True(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
||||
} else {
|
||||
assert.False(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
||||
}
|
||||
|
||||
assert.True(t, az.getID() != "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
|
||||
var (
|
||||
ch1, ch2 challenge
|
||||
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
ch1, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/getChallenge1-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"fail/getChallenge2-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 1 {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAz, err := az.toACME(tc.db, dir, prov)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAz.ID, az.getID())
|
||||
assert.Equals(t, acmeAz.Identifier, iden)
|
||||
assert.Equals(t, acmeAz.Status, StatusPending)
|
||||
|
||||
acmeCh1, err := ch1.toACME(nil, dir, prov)
|
||||
assert.FatalError(t, err)
|
||||
acmeCh2, err := ch2.toACME(nil, dir, prov)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
||||
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
||||
|
||||
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzSave(t *testing.T) {
|
||||
type test struct {
|
||||
az, old authz
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAz, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAz)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: oldAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.az.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUnmarshal(t *testing.T) {
|
||||
type test struct {
|
||||
az authz
|
||||
azb []byte
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/nil": func(t *testing.T) test {
|
||||
return test{
|
||||
azb: nil,
|
||||
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
azb: b,
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok/dns": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
azb: b,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUpdateStatus(t *testing.T) {
|
||||
type test struct {
|
||||
az, res authz
|
||||
err *Error
|
||||
db nosql.DB
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/already-invalid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/already-valid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusValid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/unexpected-status": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusReady
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok/expired": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Error = MalformedErr(errors.New("authz has expired"))
|
||||
clone.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/get-challenge-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok/valid": func(t *testing.T) test {
|
||||
var (
|
||||
ch2 challenge
|
||||
ch1Bytes = &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Error = MalformedErr(nil)
|
||||
|
||||
_ch, ok := ch2.(*dns01Challenge)
|
||||
assert.Fatal(t, ok)
|
||||
_ch.baseChallenge.Status = StatusValid
|
||||
chb, err := json.Marshal(ch2)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Status = StatusValid
|
||||
clone.Error = nil
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return chb, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/still-pending": func(t *testing.T) test {
|
||||
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
az, err := tc.az.updateStatus(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
expB, err := json.Marshal(tc.res)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expB, b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type certificate struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
AccountID string `json:"accountID"`
|
||||
OrderID string `json:"orderID"`
|
||||
Leaf []byte `json:"leaf"`
|
||||
Intermediates []byte `json:"intermediates"`
|
||||
}
|
||||
|
||||
// CertOptions options with which to create and store a cert object.
|
||||
type CertOptions struct {
|
||||
AccountID string
|
||||
OrderID string
|
||||
Leaf *x509.Certificate
|
||||
Intermediates []*x509.Certificate
|
||||
}
|
||||
|
||||
func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
cert := &certificate{
|
||||
ID: id,
|
||||
AccountID: ops.AccountID,
|
||||
OrderID: ops.OrderID,
|
||||
Leaf: leaf,
|
||||
Intermediates: intermediates,
|
||||
Created: time.Now().UTC(),
|
||||
}
|
||||
certB, err := json.Marshal(cert)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing certificate; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
|
||||
return append(c.Leaf, c.Intermediates...), nil
|
||||
}
|
||||
|
||||
func getCert(db nosql.DB, id string) (*certificate, error) {
|
||||
b, err := db.Get(certTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
|
||||
}
|
||||
var cert certificate
|
||||
if err := json.Unmarshal(b, &cert); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
@ -0,0 +1,253 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func defaultCertOps() (*CertOptions, error) {
|
||||
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CertOptions{
|
||||
AccountID: "accID",
|
||||
OrderID: "ordID",
|
||||
Leaf: crt,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newcert() (*certificate, error) {
|
||||
ops, err := defaultCertOps()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newCert(mockdb, *ops)
|
||||
}
|
||||
|
||||
func TestNewCert(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ops CertOptions
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := newCert(tc.db, tc.ops); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, *tc.id)
|
||||
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
|
||||
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: tc.ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range tc.ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, intermediates)
|
||||
|
||||
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCert(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
cert *certificate
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := getCert(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.ID, cert.ID)
|
||||
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
|
||||
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
|
||||
assert.Equals(t, tc.cert.Created, cert.Created)
|
||||
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
|
||||
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateToACME(t *testing.T) {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
acmeCert, err := cert.toACME(nil, nil)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
|
||||
}
|
@ -0,0 +1,445 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Challenge is a subset of the challenge type containing only those attributes
|
||||
// required for responses in the ACME protocol.
|
||||
type Challenge struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Validated string `json:"validated,omitempty"`
|
||||
URL string `json:"url"`
|
||||
Error *AError `json:"error,omitempty"`
|
||||
ID string `json:"-"`
|
||||
AuthzID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (c *Challenge) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Challenge ID.
|
||||
func (c *Challenge) GetID() string {
|
||||
return c.ID
|
||||
}
|
||||
|
||||
// GetAuthzID returns the parent Authz ID that owns the Challenge.
|
||||
func (c *Challenge) GetAuthzID() string {
|
||||
return c.AuthzID
|
||||
}
|
||||
|
||||
type httpGetter func(string) (*http.Response, error)
|
||||
type lookupTxt func(string) ([]string, error)
|
||||
|
||||
type validateOptions struct {
|
||||
httpGet httpGetter
|
||||
lookupTxt lookupTxt
|
||||
}
|
||||
|
||||
// challenge is the interface ACME challenege types must implement.
|
||||
type challenge interface {
|
||||
save(db nosql.DB, swap challenge) error
|
||||
validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error)
|
||||
getType() string
|
||||
getError() *AError
|
||||
getValue() string
|
||||
getStatus() string
|
||||
getID() string
|
||||
getAuthzID() string
|
||||
getToken() string
|
||||
clone() *baseChallenge
|
||||
getAccountID() string
|
||||
getValidated() time.Time
|
||||
getCreated() time.Time
|
||||
toACME(nosql.DB, *directory, provisioner.Interface) (*Challenge, error)
|
||||
}
|
||||
|
||||
// ChallengeOptions is the type used to created a new Challenge.
|
||||
type ChallengeOptions struct {
|
||||
AccountID string
|
||||
AuthzID string
|
||||
Identifier Identifier
|
||||
}
|
||||
|
||||
// baseChallenge is the base Challenge type that others build from.
|
||||
type baseChallenge struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
AuthzID string `json:"authzID"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Value string `json:"value"`
|
||||
Validated time.Time `json:"validated"`
|
||||
Created time.Time `json:"created"`
|
||||
Error *AError `json:"error"`
|
||||
}
|
||||
|
||||
func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error generating random id for ACME challenge")
|
||||
}
|
||||
token, err := randID()
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error generating token for ACME challenge")
|
||||
}
|
||||
|
||||
return &baseChallenge{
|
||||
ID: id,
|
||||
AccountID: accountID,
|
||||
AuthzID: authzID,
|
||||
Status: StatusPending,
|
||||
Token: token,
|
||||
Created: clock.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getID returns the id of the baseChallenge.
|
||||
func (bc *baseChallenge) getID() string {
|
||||
return bc.ID
|
||||
}
|
||||
|
||||
// getAuthzID returns the authz ID of the baseChallenge.
|
||||
func (bc *baseChallenge) getAuthzID() string {
|
||||
return bc.AuthzID
|
||||
}
|
||||
|
||||
// getAccountID returns the account id of the baseChallenge.
|
||||
func (bc *baseChallenge) getAccountID() string {
|
||||
return bc.AccountID
|
||||
}
|
||||
|
||||
// getType returns the type of the baseChallenge.
|
||||
func (bc *baseChallenge) getType() string {
|
||||
return bc.Type
|
||||
}
|
||||
|
||||
// getValue returns the type of the baseChallenge.
|
||||
func (bc *baseChallenge) getValue() string {
|
||||
return bc.Value
|
||||
}
|
||||
|
||||
// getStatus returns the status of the baseChallenge.
|
||||
func (bc *baseChallenge) getStatus() string {
|
||||
return bc.Status
|
||||
}
|
||||
|
||||
// getToken returns the token of the baseChallenge.
|
||||
func (bc *baseChallenge) getToken() string {
|
||||
return bc.Token
|
||||
}
|
||||
|
||||
// getValidated returns the validated time of the baseChallenge.
|
||||
func (bc *baseChallenge) getValidated() time.Time {
|
||||
return bc.Validated
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the baseChallenge.
|
||||
func (bc *baseChallenge) getCreated() time.Time {
|
||||
return bc.Created
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the baseChallenge.
|
||||
func (bc *baseChallenge) getError() *AError {
|
||||
return bc.Error
|
||||
}
|
||||
|
||||
// toACME converts the internal Challenge type into the public acmeChallenge
|
||||
// type for presentation in the ACME protocol.
|
||||
func (bc *baseChallenge) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Challenge, error) {
|
||||
ac := &Challenge{
|
||||
Type: bc.getType(),
|
||||
Status: bc.getStatus(),
|
||||
Token: bc.getToken(),
|
||||
URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()),
|
||||
ID: bc.getID(),
|
||||
AuthzID: bc.getAuthzID(),
|
||||
}
|
||||
if !bc.Validated.IsZero() {
|
||||
ac.Validated = bc.Validated.Format(time.RFC3339)
|
||||
}
|
||||
if bc.Error != nil {
|
||||
ac.Error = bc.Error
|
||||
}
|
||||
return ac, nil
|
||||
}
|
||||
|
||||
// save writes the challenge to disk. For new challenges 'old' should be nil,
|
||||
// otherwise 'old' should be a pointer to the acme challenge as it was at the
|
||||
// start of the request. This method will fail if the value currently found
|
||||
// in the bucket/row does not match the value of 'old'.
|
||||
func (bc *baseChallenge) save(db nosql.DB, old challenge) error {
|
||||
newB, err := json.Marshal(bc)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err,
|
||||
"error marshaling new acme challenge"))
|
||||
}
|
||||
var oldB []byte
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
oldB, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err,
|
||||
"error marshaling old acme challenge"))
|
||||
}
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error saving acme challenge"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error saving acme challenge; " +
|
||||
"acme challenge has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (bc *baseChallenge) clone() *baseChallenge {
|
||||
u := *bc
|
||||
return &u
|
||||
}
|
||||
|
||||
func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||
return nil, ServerInternalErr(errors.New("unimplemented"))
|
||||
}
|
||||
|
||||
func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error {
|
||||
clone := bc.clone()
|
||||
clone.Error = err.ToACME()
|
||||
return clone.save(db, bc)
|
||||
}
|
||||
|
||||
// unmarshalChallenge unmarshals a challenge type into the correct sub-type.
|
||||
func unmarshalChallenge(data []byte) (challenge, error) {
|
||||
var getType struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &getType); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type"))
|
||||
}
|
||||
|
||||
switch getType.Type {
|
||||
case "dns-01":
|
||||
var bc baseChallenge
|
||||
if err := json.Unmarshal(data, &bc); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
||||
"challenge type into dns01Challenge"))
|
||||
}
|
||||
return &dns01Challenge{&bc}, nil
|
||||
case "http-01":
|
||||
var bc baseChallenge
|
||||
if err := json.Unmarshal(data, &bc); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
||||
"challenge type into http01Challenge"))
|
||||
}
|
||||
return &http01Challenge{&bc}, nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// http01Challenge represents an http-01 acme challenge.
|
||||
type http01Challenge struct {
|
||||
*baseChallenge
|
||||
}
|
||||
|
||||
// newHTTP01Challenge returns a new acme http-01 challenge.
|
||||
func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bc.Type = "http-01"
|
||||
bc.Value = ops.Identifier.Value
|
||||
|
||||
hc := &http01Challenge{bc}
|
||||
if err := hc.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// Validate attempts to validate the challenge. If the challenge has been
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid {
|
||||
return hc, nil
|
||||
}
|
||||
url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token)
|
||||
|
||||
resp, err := vo.httpGet(url)
|
||||
if err != nil {
|
||||
if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err,
|
||||
"error doing http GET for url %s", url))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if err = hc.storeError(db,
|
||||
ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d",
|
||||
url, resp.StatusCode))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+
|
||||
"response body for url %s", url))
|
||||
}
|
||||
keyAuth := strings.Trim(string(body), "\r\n")
|
||||
|
||||
expected, err := KeyAuthorization(hc.Token, jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyAuth != expected {
|
||||
if err = hc.storeError(db,
|
||||
RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
|
||||
"expected %s, but got %s", expected, keyAuth))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// Update and store the challenge.
|
||||
upd := &http01Challenge{hc.baseChallenge.clone()}
|
||||
upd.Status = StatusValid
|
||||
upd.Error = nil
|
||||
upd.Validated = clock.Now()
|
||||
|
||||
if err := upd.save(db, hc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upd, nil
|
||||
}
|
||||
|
||||
// dns01Challenge represents an dns-01 acme challenge.
|
||||
type dns01Challenge struct {
|
||||
*baseChallenge
|
||||
}
|
||||
|
||||
// newDNS01Challenge returns a new acme dns-01 challenge.
|
||||
func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bc.Type = "dns-01"
|
||||
bc.Value = ops.Identifier.Value
|
||||
|
||||
dc := &dns01Challenge{bc}
|
||||
if err := dc.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
// KeyAuthorization creates the ACME key authorization value from a token
|
||||
// and a jwk.
|
||||
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
|
||||
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint"))
|
||||
}
|
||||
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
|
||||
return fmt.Sprintf("%s.%s", token, encPrint), nil
|
||||
}
|
||||
|
||||
// validate attempts to validate the challenge. If the challenge has been
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid {
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
txtRecords, err := vo.lookupTxt("_acme-challenge." + dc.Value)
|
||||
if err != nil {
|
||||
if err = dc.storeError(db,
|
||||
DNSErr(errors.Wrapf(err, "error looking up TXT "+
|
||||
"records for domain %s", dc.Value))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h := sha256.Sum256([]byte(expectedKeyAuth))
|
||||
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
||||
var found bool
|
||||
for _, r := range txtRecords {
|
||||
if r == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if err = dc.storeError(db,
|
||||
RejectedIdentifierErr(errors.Errorf("keyAuthorization "+
|
||||
"does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
// Update and store the challenge.
|
||||
upd := &dns01Challenge{dc.baseChallenge.clone()}
|
||||
upd.Status = StatusValid
|
||||
upd.Error = nil
|
||||
upd.Validated = time.Now().UTC()
|
||||
|
||||
if err := upd.save(db, dc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upd, nil
|
||||
}
|
||||
|
||||
// getChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||
func getChallenge(db nosql.DB, id string) (challenge, error) {
|
||||
b, err := db.Get(challengeTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id))
|
||||
}
|
||||
ch, err := unmarshalChallenge(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ch, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,76 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
)
|
||||
|
||||
// SignAuthority is the interface implemented by a CA authority.
|
||||
type SignAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Identifier encodes the type that an order pertains to.
|
||||
type Identifier struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
var (
|
||||
accountTable = []byte("acme-accounts")
|
||||
accountByKeyIDTable = []byte("acme-keyID-accountID-index")
|
||||
authzTable = []byte("acme-authzs")
|
||||
challengeTable = []byte("acme-challenges")
|
||||
nonceTable = []byte("nonce-table")
|
||||
orderTable = []byte("acme-orders")
|
||||
ordersByAccountIDTable = []byte("acme-account-orders-index")
|
||||
certTable = []byte("acme-certs")
|
||||
)
|
||||
|
||||
var (
|
||||
// StatusValid -- valid
|
||||
StatusValid = "valid"
|
||||
// StatusInvalid -- invalid
|
||||
StatusInvalid = "invalid"
|
||||
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||
StatusPending = "pending"
|
||||
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||
StatusDeactivated = "deactivated"
|
||||
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||
StatusReady = "ready"
|
||||
//statusExpired = "expired"
|
||||
//statusActive = "active"
|
||||
//statusProcessing = "processing"
|
||||
)
|
||||
|
||||
var idLen = 32
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.Alphanumeric(idLen)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock int
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Round(time.Second)
|
||||
}
|
||||
|
||||
var clock = new(Clock)
|
||||
|
||||
// URLSafeProvisionerName returns a path escaped version of the ACME provisioner
|
||||
// ID that is safe to use in URL paths.
|
||||
func URLSafeProvisionerName(p provisioner.Interface) string {
|
||||
return url.PathEscape(p.GetName())
|
||||
}
|
@ -0,0 +1,120 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Directory represents an ACME directory for configuring clients.
|
||||
type Directory struct {
|
||||
NewNonce string `json:"newNonce,omitempty"`
|
||||
NewAccount string `json:"newAccount,omitempty"`
|
||||
NewOrder string `json:"newOrder,omitempty"`
|
||||
NewAuthz string `json:"newAuthz,omitempty"`
|
||||
RevokeCert string `json:"revokeCert,omitempty"`
|
||||
KeyChange string `json:"keyChange,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging for the Directory type.
|
||||
func (d *Directory) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
type directory struct {
|
||||
prefix, dns string
|
||||
}
|
||||
|
||||
// newDirectory returns a new Directory type.
|
||||
func newDirectory(dns, prefix string) *directory {
|
||||
return &directory{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Link captures the link type.
|
||||
type Link int
|
||||
|
||||
const (
|
||||
// NewNonceLink new-nonce
|
||||
NewNonceLink Link = iota
|
||||
// NewAccountLink new-account
|
||||
NewAccountLink
|
||||
// AccountLink account
|
||||
AccountLink
|
||||
// OrderLink order
|
||||
OrderLink
|
||||
// NewOrderLink new-order
|
||||
NewOrderLink
|
||||
// OrdersByAccountLink list of orders owned by account
|
||||
OrdersByAccountLink
|
||||
// FinalizeLink finalize order
|
||||
FinalizeLink
|
||||
// NewAuthzLink authz
|
||||
NewAuthzLink
|
||||
// AuthzLink new-authz
|
||||
AuthzLink
|
||||
// ChallengeLink challenge
|
||||
ChallengeLink
|
||||
// CertificateLink certificate
|
||||
CertificateLink
|
||||
// DirectoryLink directory
|
||||
DirectoryLink
|
||||
// RevokeCertLink revoke certificate
|
||||
RevokeCertLink
|
||||
// KeyChangeLink key rollover
|
||||
KeyChangeLink
|
||||
)
|
||||
|
||||
func (l Link) String() string {
|
||||
switch l {
|
||||
case NewNonceLink:
|
||||
return "new-nonce"
|
||||
case NewAccountLink:
|
||||
return "new-account"
|
||||
case AccountLink:
|
||||
return "account"
|
||||
case NewOrderLink:
|
||||
return "new-order"
|
||||
case OrderLink:
|
||||
return "order"
|
||||
case NewAuthzLink:
|
||||
return "new-authz"
|
||||
case AuthzLink:
|
||||
return "authz"
|
||||
case ChallengeLink:
|
||||
return "challenge"
|
||||
case CertificateLink:
|
||||
return "certificate"
|
||||
case DirectoryLink:
|
||||
return "directory"
|
||||
case RevokeCertLink:
|
||||
return "revoke-cert"
|
||||
case KeyChangeLink:
|
||||
return "key-change"
|
||||
default:
|
||||
return "unexpected"
|
||||
}
|
||||
}
|
||||
|
||||
// getLink returns an absolute or partial path to the given resource.
|
||||
func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string {
|
||||
var link string
|
||||
switch typ {
|
||||
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
|
||||
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
|
||||
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
|
||||
case OrdersByAccountLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
|
||||
case FinalizeLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
|
||||
}
|
||||
if abs {
|
||||
return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link)
|
||||
}
|
||||
return link
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestDirectoryGetLink(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
dir := newDirectory(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provID := URLSafeProvisionerName(prov)
|
||||
|
||||
type newTest struct {
|
||||
actual, expected string
|
||||
}
|
||||
assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID))
|
||||
assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID))
|
||||
assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID))
|
||||
assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID))
|
||||
assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID))
|
||||
assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID))
|
||||
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID))
|
||||
assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID))
|
||||
assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID))
|
||||
assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID))
|
||||
assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID))
|
||||
assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID))
|
||||
assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID))
|
||||
assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID))
|
||||
assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID))
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// nonce contains nonce metadata used in the ACME protocol.
|
||||
type nonce struct {
|
||||
ID string
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// newNonce creates, stores, and returns an ACME replay-nonce.
|
||||
func newNonce(db nosql.DB) (*nonce, error) {
|
||||
_id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||
n := &nonce{
|
||||
ID: id,
|
||||
Created: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing nonce; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
// useNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func useNonce(db nosql.DB, nonce string) error {
|
||||
err := db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Get,
|
||||
},
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Delete,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
return BadNonceErr(nil)
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
@ -0,0 +1,163 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestNewNonce(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if n, err := newNonce(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, n.ID, *tc.id)
|
||||
|
||||
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseNonce(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/update-not-found": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return database.ErrNotFound
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: BadNonceErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/update-error": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := useNonce(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,342 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
var defaultOrderExpiry = time.Hour * 24
|
||||
|
||||
// Order contains order metadata for the ACME protocol order type.
|
||||
type Order struct {
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires,omitempty"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore string `json:"notBefore,omitempty"`
|
||||
NotAfter string `json:"notAfter,omitempty"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
Authorizations []string `json:"authorizations"`
|
||||
Finalize string `json:"finalize"`
|
||||
Certificate string `json:"certificate,omitempty"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (o *Order) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Order ID.
|
||||
func (o *Order) GetID() string {
|
||||
return o.ID
|
||||
}
|
||||
|
||||
// OrderOptions options with which to create a new Order.
|
||||
type OrderOptions struct {
|
||||
AccountID string `json:"accID"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
}
|
||||
|
||||
type order struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Created time.Time `json:"created"`
|
||||
Expires time.Time `json:"expires,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Authorizations []string `json:"authorizations"`
|
||||
Certificate string `json:"certificate,omitempty"`
|
||||
}
|
||||
|
||||
// newOrder returns a new Order type.
|
||||
func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authzs := make([]string, len(ops.Identifiers))
|
||||
for i, identifier := range ops.Identifiers {
|
||||
authz, err := newAuthz(db, ops.AccountID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authzs[i] = authz.getID()
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
o := &order{
|
||||
ID: id,
|
||||
AccountID: ops.AccountID,
|
||||
Created: now,
|
||||
Status: StatusPending,
|
||||
Expires: now.Add(defaultOrderExpiry),
|
||||
Identifiers: ops.Identifiers,
|
||||
NotBefore: ops.NotBefore,
|
||||
NotAfter: ops.NotAfter,
|
||||
Authorizations: authzs,
|
||||
}
|
||||
if err := o.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update the "order IDs by account ID" index //
|
||||
oids, err := getOrderIDsByAccount(db, ops.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newOids := append(oids, o.ID)
|
||||
if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil {
|
||||
db.Del(orderTable, []byte(o.ID))
|
||||
return nil, err
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
type orderIDs []string
|
||||
|
||||
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
|
||||
var (
|
||||
err error
|
||||
oldb []byte
|
||||
)
|
||||
if len(old) == 0 {
|
||||
oldb = nil
|
||||
} else {
|
||||
oldb, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
|
||||
}
|
||||
}
|
||||
newb, err := json.Marshal(oids)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing order IDs "+
|
||||
"for account %s; order IDs changed since last read", accID))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (o *order) save(db nosql.DB, old *order) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
newB, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing order"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing order; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// updateStatus updates order status if necessary.
|
||||
func (o *order) updateStatus(db nosql.DB) (*order, error) {
|
||||
_newOrder := *o
|
||||
newOrder := &_newOrder
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch o.Status {
|
||||
case StatusInvalid:
|
||||
return o, nil
|
||||
case StatusValid:
|
||||
return o, nil
|
||||
case StatusReady:
|
||||
// check expiry
|
||||
if now.After(o.Expires) {
|
||||
newOrder.Status = StatusInvalid
|
||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||
break
|
||||
}
|
||||
return o, nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(o.Expires) {
|
||||
newOrder.Status = StatusInvalid
|
||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var count = map[string]int{
|
||||
StatusValid: 0,
|
||||
StatusInvalid: 0,
|
||||
StatusPending: 0,
|
||||
}
|
||||
for _, azID := range o.Authorizations {
|
||||
authz, err := getAuthz(db, azID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz, err = authz.updateStatus(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st := authz.getStatus()
|
||||
count[st]++
|
||||
}
|
||||
switch {
|
||||
case count[StatusInvalid] > 0:
|
||||
newOrder.Status = StatusInvalid
|
||||
case count[StatusPending] > 0:
|
||||
break
|
||||
case count[StatusValid] == len(o.Authorizations):
|
||||
newOrder.Status = StatusReady
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.New("unexpected authz status"))
|
||||
}
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
|
||||
}
|
||||
|
||||
if err := newOrder.save(db, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newOrder, nil
|
||||
}
|
||||
|
||||
// finalize signs a certificate if the necessary conditions for Order completion
|
||||
// have been met.
|
||||
func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) {
|
||||
var err error
|
||||
if o, err = o.updateStatus(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch o.Status {
|
||||
case StatusInvalid:
|
||||
return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
|
||||
case StatusValid:
|
||||
return o, nil
|
||||
case StatusPending:
|
||||
return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
|
||||
case StatusReady:
|
||||
break
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
|
||||
}
|
||||
|
||||
// Validate identifier names against CSR alternative names //
|
||||
csrNames := make(map[string]int)
|
||||
for _, n := range csr.DNSNames {
|
||||
csrNames[n] = 1
|
||||
}
|
||||
orderNames := make(map[string]int)
|
||||
for _, n := range o.Identifiers {
|
||||
orderNames[n.Value] = 1
|
||||
}
|
||||
if !reflect.DeepEqual(csrNames, orderNames) {
|
||||
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly"))
|
||||
}
|
||||
|
||||
// Get authorizations from the ACME provisioner.
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
signOps, err := p.AuthorizeSign(ctx, "")
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
|
||||
}
|
||||
|
||||
// Create and store a new certificate.
|
||||
leaf, inter, err := auth.Sign(csr, provisioner.Options{
|
||||
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
|
||||
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
|
||||
}, signOps...)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
|
||||
}
|
||||
|
||||
cert, err := newCert(db, CertOptions{
|
||||
AccountID: o.AccountID,
|
||||
OrderID: o.ID,
|
||||
Leaf: leaf,
|
||||
Intermediates: []*x509.Certificate{inter},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_newOrder := *o
|
||||
newOrder := &_newOrder
|
||||
newOrder.Certificate = cert.ID
|
||||
newOrder.Status = StatusValid
|
||||
if err := newOrder.save(db, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newOrder, nil
|
||||
}
|
||||
|
||||
// getOrder retrieves and unmarshals an ACME Order type from the database.
|
||||
func getOrder(db nosql.DB, id string) (*order, error) {
|
||||
b, err := db.Get(orderTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
|
||||
}
|
||||
var o order
|
||||
if err := json.Unmarshal(b, &o); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
|
||||
}
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// toACME converts the internal Order type into the public acmeOrder type for
|
||||
// presentation in the ACME protocol.
|
||||
func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) {
|
||||
azs := make([]string, len(o.Authorizations))
|
||||
for i, aid := range o.Authorizations {
|
||||
azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid)
|
||||
}
|
||||
ao := &Order{
|
||||
Status: o.Status,
|
||||
Expires: o.Expires.Format(time.RFC3339),
|
||||
Identifiers: o.Identifiers,
|
||||
NotBefore: o.NotBefore.Format(time.RFC3339),
|
||||
NotAfter: o.NotAfter.Format(time.RFC3339),
|
||||
Authorizations: azs,
|
||||
Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID),
|
||||
ID: o.ID,
|
||||
}
|
||||
|
||||
if o.Certificate != "" {
|
||||
ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate)
|
||||
}
|
||||
return ao, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,85 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
||||
// provisioning flow.
|
||||
type ACME struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
claimer *Claimer
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (p ACME) GetID() string {
|
||||
return "acme/" + p.Name
|
||||
}
|
||||
|
||||
// GetTokenID returns the identifier of the token.
|
||||
func (p *ACME) GetTokenID(ott string) (string, error) {
|
||||
return "", errors.New("acme provisioner does not implement GetTokenID")
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
func (p *ACME) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
||||
// GetType returns the type of provisioner.
|
||||
func (p *ACME) GetType() Type {
|
||||
return TypeACME
|
||||
}
|
||||
|
||||
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
|
||||
func (p *ACME) GetEncryptedKey() (string, string, bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of a JWK type.
|
||||
func (p *ACME) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
case p.Name == "":
|
||||
return errors.New("provisioner name cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AuthorizeRevoke is not implemented yet for the ACME provisioner.
|
||||
func (p *ACME) AuthorizeRevoke(token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *ACME) AuthorizeSign(ctx context.Context, _ string) ([]SignOption, error) {
|
||||
if m := MethodFromContext(ctx); m != SignMethod {
|
||||
return nil, errors.Errorf("unexpected method type %d in context", m)
|
||||
}
|
||||
return []SignOption{
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
defaultPublicKeyValidator{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal is not implemented for the ACME provisioner.
|
||||
func (p *ACME) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,184 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestACME_Getters(t *testing.T) {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
id := "acme/" + p.Name
|
||||
if got := p.GetID(); got != id {
|
||||
t.Errorf("ACME.GetID() = %v, want %v", got, id)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("ACME.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeACME {
|
||||
t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_Init(t *testing.T) {
|
||||
type ProvisionerValidateTest struct {
|
||||
p *ACME
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{
|
||||
Type: "ACME",
|
||||
},
|
||||
err: errors.New("provisioner name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo"},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo", Type: "bar"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}
|
||||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.p.Init(config)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeRevoke(t *testing.T) {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
assert.Nil(t, p.AuthorizeRevoke(""))
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *ACME
|
||||
args args
|
||||
err error
|
||||
}{
|
||||
{"ok", p1, args{nil}, nil},
|
||||
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(tt.args.cert); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeSign(t *testing.T) {
|
||||
p1, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *ACME
|
||||
method Method
|
||||
err error
|
||||
}{
|
||||
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
|
||||
{"ok", p1, SignMethod, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), tt.method)
|
||||
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, got) {
|
||||
assert.Len(t, 4, got)
|
||||
|
||||
_pdd := got[0]
|
||||
pdd, ok := _pdd.(profileDefaultDuration)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, pdd, profileDefaultDuration(86400000000000))
|
||||
|
||||
_peo := got[1]
|
||||
peo, ok := _peo.(*provisionerExtensionOption)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, peo.Type, 6)
|
||||
assert.Equals(t, peo.Name, "test@acme-provisioner.com")
|
||||
assert.Equals(t, peo.CredentialID, "")
|
||||
assert.Equals(t, peo.KeyValuePairs, nil)
|
||||
|
||||
_vv := got[2]
|
||||
vv, ok := _vv.(*validityValidator)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, vv.min, time.Duration(300000000000))
|
||||
assert.Equals(t, vv.max, time.Duration(86400000000000))
|
||||
|
||||
_dpkv := got[3]
|
||||
_, ok = _dpkv.(defaultPublicKeyValidator)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
-----BEGIN CERTIFICATE REQUEST-----
|
||||
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
|
||||
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
|
||||
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
|
||||
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
|
||||
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
|
||||
OI+cWOIc/IGwqZul/zEF5dani5ihOL7UwA==
|
||||
-----END CERTIFICATE REQUEST-----
|
@ -0,0 +1,8 @@
|
||||
-----BEGIN CERTIFICATE REQUEST-----
|
||||
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
|
||||
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
|
||||
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
|
||||
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
|
||||
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
|
||||
OI+cWOIc/IGwqZul/zEF5dani5ihOR7UwA==
|
||||
-----END CERTIFICATE REQUEST-----
|
@ -0,0 +1,354 @@
|
||||
package ca
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// ACMEClient implements an HTTP client to an ACME API.
|
||||
type ACMEClient struct {
|
||||
client *http.Client
|
||||
dirLoc string
|
||||
dir *acme.Directory
|
||||
acc *acme.Account
|
||||
Key *jose.JSONWebKey
|
||||
kid string
|
||||
}
|
||||
|
||||
// NewACMEClient initializes a new ACMEClient.
|
||||
func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) {
|
||||
// Retrieve transport from options.
|
||||
o := new(clientOptions)
|
||||
if err := o.apply(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr, err := o.getTransport(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ac := &ACMEClient{
|
||||
client: &http.Client{
|
||||
Transport: tr,
|
||||
},
|
||||
dirLoc: endpoint,
|
||||
}
|
||||
|
||||
resp, err := ac.client.Get(endpoint)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", endpoint)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
var dir acme.Directory
|
||||
if err := readJSON(resp.Body, &dir); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", endpoint)
|
||||
}
|
||||
|
||||
ac.dir = &dir
|
||||
|
||||
ac.Key, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nar := &acmeAPI.NewAccountRequest{
|
||||
Contact: contact,
|
||||
TermsOfServiceAgreed: true,
|
||||
}
|
||||
payload, err := json.Marshal(nar)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling new account request")
|
||||
}
|
||||
|
||||
resp, err = ac.post(payload, ac.dir.NewAccount, withJWK(ac))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
var acc acme.Account
|
||||
if err := readJSON(resp.Body, &acc); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", dir.NewAccount)
|
||||
}
|
||||
ac.acc = &acc
|
||||
ac.kid = resp.Header.Get("Location")
|
||||
|
||||
return ac, nil
|
||||
}
|
||||
|
||||
// GetDirectory makes a directory request to the ACME api and returns an
|
||||
// ACME directory object.
|
||||
func (c *ACMEClient) GetDirectory() (*acme.Directory, error) {
|
||||
return c.dir, nil
|
||||
}
|
||||
|
||||
// GetNonce makes a nonce request to the ACME api and returns an
|
||||
// ACME directory object.
|
||||
func (c *ACMEClient) GetNonce() (string, error) {
|
||||
resp, err := c.client.Get(c.dir.NewNonce)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", readACMEError(resp.Body)
|
||||
}
|
||||
return resp.Header.Get("Replay-Nonce"), nil
|
||||
}
|
||||
|
||||
type withHeaderOption func(so *jose.SignerOptions)
|
||||
|
||||
func withJWK(c *ACMEClient) withHeaderOption {
|
||||
return func(so *jose.SignerOptions) {
|
||||
so.WithHeader("jwk", c.Key.Public())
|
||||
}
|
||||
}
|
||||
|
||||
func withKid(c *ACMEClient) withHeaderOption {
|
||||
return func(so *jose.SignerOptions) {
|
||||
so.WithHeader("kid", c.kid)
|
||||
}
|
||||
}
|
||||
|
||||
// serialize serializes a json web signature and doesn't omit empty fields.
|
||||
func serialize(obj *jose.JSONWebSignature) (string, error) {
|
||||
raw, err := obj.CompactSerialize()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error serializing JWS")
|
||||
}
|
||||
parts := strings.Split(raw, ".")
|
||||
msg := struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Signature string `json:"signature"`
|
||||
}{Protected: parts[0], Payload: parts[1], Signature: parts[2]}
|
||||
b, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error marshaling jws message")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOption) (*http.Response, error) {
|
||||
if c.Key == nil {
|
||||
return nil, errors.New("acme client not configured with account")
|
||||
}
|
||||
nonce, err := c.GetNonce()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("nonce", nonce)
|
||||
so.WithHeader("url", url)
|
||||
for _, hop := range headerOps {
|
||||
hop(so)
|
||||
}
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(c.Key.Algorithm),
|
||||
Key: c.Key.Key,
|
||||
}, so)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error creating JWS signer")
|
||||
}
|
||||
signed, err := signer.Sign(payload)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("error signing payload: %s", strings.TrimPrefix(err.Error(), "square/go-jose: "))
|
||||
}
|
||||
raw, err := serialize(signed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.client.Post(url, "application/jose+json", strings.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", c.dir.NewOrder)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewOrder creates and returns the information for a new ACME order.
|
||||
func (c *ACMEClient) NewOrder(payload []byte) (*acme.Order, error) {
|
||||
resp, err := c.post(payload, c.dir.NewOrder, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var o acme.Order
|
||||
if err := readJSON(resp.Body, &o); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", c.dir.NewOrder)
|
||||
}
|
||||
o.ID = resp.Header.Get("Location")
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// GetChallenge returns the Challenge at the given path.
|
||||
// With the validate parameter set to True this method will attempt to validate the
|
||||
// challenge before returning it.
|
||||
func (c *ACMEClient) GetChallenge(url string) (*acme.Challenge, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var ch acme.Challenge
|
||||
if err := readJSON(resp.Body, &ch); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||
}
|
||||
return &ch, nil
|
||||
}
|
||||
|
||||
// ValidateChallenge returns the Challenge at the given path.
|
||||
// With the validate parameter set to True this method will attempt to validate the
|
||||
// challenge before returning it.
|
||||
func (c *ACMEClient) ValidateChallenge(url string) error {
|
||||
resp, err := c.post([]byte("{}"), url, withKid(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return readACMEError(resp.Body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthz returns the Authz at the given path.
|
||||
func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var az acme.Authz
|
||||
if err := readJSON(resp.Body, &az); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||
}
|
||||
return &az, nil
|
||||
}
|
||||
|
||||
// GetOrder returns the Order at the given path.
|
||||
func (c *ACMEClient) GetOrder(url string) (*acme.Order, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var o acme.Order
|
||||
if err := readJSON(resp.Body, &o); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||
}
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// FinalizeOrder makes a finalize request to the ACME api.
|
||||
func (c *ACMEClient) FinalizeOrder(url string, csr *x509.CertificateRequest) error {
|
||||
payload, err := json.Marshal(acmeAPI.FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling finalize request")
|
||||
}
|
||||
resp, err := c.post(payload, url, withKid(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return readACMEError(resp.Body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificate retrieves the certificate along with all intermediates.
|
||||
func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Certificate, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, nil, readACMEError(resp.Body)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error reading GET certificate response")
|
||||
}
|
||||
|
||||
var certs []*x509.Certificate
|
||||
|
||||
block, rest := pem.Decode(bodyBytes)
|
||||
if block == nil {
|
||||
return nil, nil, errors.New("failed to parse any certificates from response")
|
||||
}
|
||||
for block != nil {
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing certificate pem response")
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
block, rest = pem.Decode(rest)
|
||||
}
|
||||
|
||||
return certs[0], certs[1:], nil
|
||||
}
|
||||
|
||||
// GetAccountOrders retrieves the orders belonging to the given account.
|
||||
func (c *ACMEClient) GetAccountOrders() ([]string, error) {
|
||||
if c.acc == nil {
|
||||
return nil, errors.New("acme client not configured with account")
|
||||
}
|
||||
resp, err := c.post(nil, c.acc.Orders, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var orders []string
|
||||
if err := readJSON(resp.Body, &orders); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders)
|
||||
}
|
||||
|
||||
return orders, nil
|
||||
}
|
||||
|
||||
func readACMEError(r io.ReadCloser) error {
|
||||
defer r.Close()
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error reading from body")
|
||||
}
|
||||
ae := new(acme.AError)
|
||||
err = json.Unmarshal(b, &ae)
|
||||
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
||||
if err != nil || len(ae.Error()) == 0 {
|
||||
fmt.Printf("b = %s\n", b)
|
||||
// Throw up our hands.
|
||||
return errors.Errorf("%s", b)
|
||||
}
|
||||
return ae
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue