Merge branch 'master' into herman/allow-deny
commit
c1424036bf
@ -0,0 +1,155 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// ExternalAccountBinding represents the ACME externalAccountBinding JWS
|
||||
type ExternalAccountBinding struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Sig string `json:"signature"`
|
||||
}
|
||||
|
||||
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
||||
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
||||
}
|
||||
|
||||
if !acmeProv.RequireEAB {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if nar.ExternalAccountBinding == nil {
|
||||
return nil, acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided")
|
||||
}
|
||||
|
||||
eabJSONBytes, err := json.Marshal(nar.ExternalAccountBinding)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error marshaling externalAccountBinding into bytes")
|
||||
}
|
||||
|
||||
eabJWS, err := jose.ParseJWS(string(eabJSONBytes))
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error parsing externalAccountBinding jws")
|
||||
}
|
||||
|
||||
// TODO(hs): implement strategy pattern to allow for different ways of verification (i.e. webhook call) based on configuration?
|
||||
|
||||
keyID, acmeErr := validateEABJWS(ctx, eabJWS)
|
||||
if acmeErr != nil {
|
||||
return nil, acmeErr
|
||||
}
|
||||
|
||||
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
if err != nil {
|
||||
if _, ok := err.(*acme.Error); ok {
|
||||
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
||||
}
|
||||
return nil, acme.WrapErrorISE(err, "error retrieving external account key")
|
||||
}
|
||||
|
||||
if externalAccountKey.AlreadyBound() {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
|
||||
}
|
||||
|
||||
payload, err := eabJWS.Verify(externalAccountKey.KeyBytes)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
|
||||
}
|
||||
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var payloadJWK *jose.JSONWebKey
|
||||
if err = json.Unmarshal(payload, &payloadJWK); err != nil {
|
||||
return nil, acme.WrapError(acme.ErrorMalformedType, err, "error unmarshaling payload into jwk")
|
||||
}
|
||||
|
||||
if !keysAreEqual(jwk, payloadJWK) {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "keys in jws and eab payload do not match")
|
||||
}
|
||||
|
||||
return externalAccountKey, nil
|
||||
}
|
||||
|
||||
// keysAreEqual performs an equality check on two JWKs by comparing
|
||||
// the (base64 encoding) of the Key IDs.
|
||||
func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
||||
if x == nil || y == nil {
|
||||
return false
|
||||
}
|
||||
digestX, errX := acme.KeyToID(x)
|
||||
digestY, errY := acme.KeyToID(y)
|
||||
if errX != nil || errY != nil {
|
||||
return false
|
||||
}
|
||||
return digestX == digestY
|
||||
}
|
||||
|
||||
// validateEABJWS verifies the contents of the External Account Binding JWS.
|
||||
// The protected header of the JWS MUST meet the following criteria:
|
||||
// o The "alg" field MUST indicate a MAC-based algorithm
|
||||
// o The "kid" field MUST contain the key identifier provided by the CA
|
||||
// o The "nonce" field MUST NOT be present
|
||||
// o The "url" field MUST be set to the same value as the outer JWS
|
||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||
|
||||
if jws == nil {
|
||||
return "", acme.NewErrorISE("no JWS provided")
|
||||
}
|
||||
|
||||
if len(jws.Signatures) != 1 {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "JWS must have one signature")
|
||||
}
|
||||
|
||||
header := jws.Signatures[0].Protected
|
||||
algorithm := header.Algorithm
|
||||
keyID := header.KeyID
|
||||
nonce := header.Nonce
|
||||
|
||||
if !(algorithm == jose.HS256 || algorithm == jose.HS384 || algorithm == jose.HS512) {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'alg' field set to invalid algorithm '%s'", algorithm)
|
||||
}
|
||||
|
||||
if keyID == "" {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'kid' field is required")
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'nonce' must not be present")
|
||||
}
|
||||
|
||||
jwsURL, ok := header.ExtraHeaders["url"]
|
||||
if !ok {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'url' field is required")
|
||||
}
|
||||
|
||||
outerJWS, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
return "", acme.WrapErrorISE(err, "could not retrieve outer JWS from context")
|
||||
}
|
||||
|
||||
if len(outerJWS.Signatures) != 1 {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature")
|
||||
}
|
||||
|
||||
outerJWSURL, ok := outerJWS.Signatures[0].Protected.ExtraHeaders["url"]
|
||||
if !ok {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'url' field must be set in outer JWS")
|
||||
}
|
||||
|
||||
if jwsURL != outerJWSURL {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'url' field is not the same value as the outer JWS")
|
||||
}
|
||||
|
||||
return keyID, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,380 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// externalAccountKeyMutex for read/write locking of EAK operations.
|
||||
var externalAccountKeyMutex sync.RWMutex
|
||||
|
||||
// referencesByProvisionerIndexMutex for locking referencesByProvisioner index operations.
|
||||
var referencesByProvisionerIndexMutex sync.Mutex
|
||||
|
||||
type dbExternalAccountKey struct {
|
||||
ID string `json:"id"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"accountID,omitempty"`
|
||||
KeyBytes []byte `json:"key"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt"`
|
||||
}
|
||||
|
||||
type dbExternalAccountKeyReference struct {
|
||||
Reference string `json:"reference"`
|
||||
ExternalAccountKeyID string `json:"externalAccountKeyID"`
|
||||
}
|
||||
|
||||
// getDBExternalAccountKey retrieves and unmarshals dbExternalAccountKey.
|
||||
func (db *DB) getDBExternalAccountKey(ctx context.Context, id string) (*dbExternalAccountKey, error) {
|
||||
data, err := db.db.Get(externalAccountKeyTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return nil, acme.ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrapf(err, "error loading external account key %s", id)
|
||||
}
|
||||
|
||||
dbeak := new(dbExternalAccountKey)
|
||||
if err = json.Unmarshal(data, dbeak); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling external account key %s into dbExternalAccountKey", id)
|
||||
}
|
||||
|
||||
return dbeak, nil
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey creates a new External Account Binding key with a name
|
||||
func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
keyID, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
random := make([]byte, 32)
|
||||
_, err = rand.Read(random)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbeak := &dbExternalAccountKey{
|
||||
ID: keyID,
|
||||
ProvisionerID: provisionerID,
|
||||
Reference: reference,
|
||||
KeyBytes: random,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
if err := db.save(ctx, keyID, dbeak, nil, "external_account_key", externalAccountKeyTable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.addEAKID(ctx, provisionerID, dbeak.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbeak.Reference != "" {
|
||||
dbExternalAccountKeyReference := &dbExternalAccountKeyReference{
|
||||
Reference: dbeak.Reference,
|
||||
ExternalAccountKeyID: dbeak.ID,
|
||||
}
|
||||
if err := db.save(ctx, referenceKey(provisionerID, dbeak.Reference), dbExternalAccountKeyReference, nil, "external_account_key_reference", externalAccountKeyIDsByReferenceTable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &acme.ExternalAccountKey{
|
||||
ID: dbeak.ID,
|
||||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
KeyBytes: dbeak.KeyBytes,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKey retrieves an External Account Binding key by KeyID
|
||||
func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
dbeak, err := db.getDBExternalAccountKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbeak.ProvisionerID != provisionerID {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created")
|
||||
}
|
||||
|
||||
return &acme.ExternalAccountKey{
|
||||
ID: dbeak.ID,
|
||||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
KeyBytes: dbeak.KeyBytes,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
dbeak, err := db.getDBExternalAccountKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error loading ACME EAB Key with Key ID %s", keyID)
|
||||
}
|
||||
|
||||
if dbeak.ProvisionerID != provisionerID {
|
||||
return errors.New("provisioner does not match provisioner for which the EAB key was created")
|
||||
}
|
||||
|
||||
if dbeak.Reference != "" {
|
||||
if err := db.db.Del(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, dbeak.Reference))); err != nil {
|
||||
return errors.Wrapf(err, "error deleting ACME EAB Key reference with Key ID %s and reference %s", keyID, dbeak.Reference)
|
||||
}
|
||||
}
|
||||
if err := db.db.Del(externalAccountKeyTable, []byte(keyID)); err != nil {
|
||||
return errors.Wrapf(err, "error deleting ACME EAB Key with Key ID %s", keyID)
|
||||
}
|
||||
if err := db.deleteEAKID(ctx, provisionerID, keyID); err != nil {
|
||||
return errors.Wrapf(err, "error removing ACME EAB Key ID %s", keyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner
|
||||
func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) {
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
// cursor and limit are ignored in open source, at least for now.
|
||||
|
||||
var eakIDs []string
|
||||
r, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID))
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return nil, "", errors.Wrapf(err, "error loading ACME EAB Key IDs for provisioner %s", provisionerID)
|
||||
}
|
||||
// it may happen that no record is found; we'll continue with an empty slice
|
||||
} else {
|
||||
if err := json.Unmarshal(r, &eakIDs); err != nil {
|
||||
return nil, "", errors.Wrapf(err, "error unmarshaling ACME EAB Key IDs for provisioner %s", provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
keys := []*acme.ExternalAccountKey{}
|
||||
for _, eakID := range eakIDs {
|
||||
if eakID == "" {
|
||||
continue // shouldn't happen; just in case
|
||||
}
|
||||
eak, err := db.getDBExternalAccountKey(ctx, eakID)
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return nil, "", errors.Wrapf(err, "error retrieving ACME EAB Key for provisioner %s and keyID %s", provisionerID, eakID)
|
||||
}
|
||||
}
|
||||
keys = append(keys, &acme.ExternalAccountKey{
|
||||
ID: eak.ID,
|
||||
KeyBytes: eak.KeyBytes,
|
||||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
CreatedAt: eak.CreatedAt,
|
||||
BoundAt: eak.BoundAt,
|
||||
})
|
||||
}
|
||||
|
||||
return keys, "", nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKeyByReference retrieves an External Account Binding key with unique reference
|
||||
func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
if reference == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
k, err := db.db.Get(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, reference)))
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return nil, acme.ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading ACME EAB key for reference %s", reference)
|
||||
}
|
||||
dbExternalAccountKeyReference := new(dbExternalAccountKeyReference)
|
||||
if err := json.Unmarshal(k, dbExternalAccountKeyReference); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling ACME EAB key for reference %s", reference)
|
||||
}
|
||||
|
||||
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
|
||||
}
|
||||
|
||||
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
old, err := db.getDBExternalAccountKey(ctx, eak.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if old.ProvisionerID != provisionerID {
|
||||
return errors.New("provisioner does not match provisioner for which the EAB key was created")
|
||||
}
|
||||
|
||||
if old.ProvisionerID != eak.ProvisionerID {
|
||||
return errors.New("cannot change provisioner for an existing ACME EAB Key")
|
||||
}
|
||||
|
||||
if old.Reference != eak.Reference {
|
||||
return errors.New("cannot change reference for an existing ACME EAB Key")
|
||||
}
|
||||
|
||||
nu := dbExternalAccountKey{
|
||||
ID: eak.ID,
|
||||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
KeyBytes: eak.KeyBytes,
|
||||
CreatedAt: eak.CreatedAt,
|
||||
BoundAt: eak.BoundAt,
|
||||
}
|
||||
|
||||
return db.save(ctx, nu.ID, nu, old, "external_account_key", externalAccountKeyTable)
|
||||
}
|
||||
|
||||
func (db *DB) addEAKID(ctx context.Context, provisionerID, eakID string) error {
|
||||
referencesByProvisionerIndexMutex.Lock()
|
||||
defer referencesByProvisionerIndexMutex.Unlock()
|
||||
|
||||
if eakID == "" {
|
||||
return errors.Errorf("can't add empty eakID for provisioner %s", provisionerID)
|
||||
}
|
||||
|
||||
var eakIDs []string
|
||||
b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID))
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
// it may happen that no record is found; we'll continue with an empty slice
|
||||
} else {
|
||||
if err := json.Unmarshal(b, &eakIDs); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range eakIDs {
|
||||
if id == eakID {
|
||||
// return an error when a duplicate ID is found
|
||||
return errors.Errorf("eakID %s already exists for provisioner %s", eakID, provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
var newEAKIDs []string
|
||||
newEAKIDs = append(newEAKIDs, eakIDs...)
|
||||
newEAKIDs = append(newEAKIDs, eakID)
|
||||
|
||||
var (
|
||||
_old interface{} = eakIDs
|
||||
_new interface{} = newEAKIDs
|
||||
)
|
||||
|
||||
// ensure that the DB gets the expected value when the slice is empty; otherwise
|
||||
// it'll return with an error that indicates that the DBs view of the data is
|
||||
// different from the last read (i.e. _old is different from what the DB has).
|
||||
if len(eakIDs) == 0 {
|
||||
_old = nil
|
||||
}
|
||||
|
||||
if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil {
|
||||
return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) deleteEAKID(ctx context.Context, provisionerID, eakID string) error {
|
||||
referencesByProvisionerIndexMutex.Lock()
|
||||
defer referencesByProvisionerIndexMutex.Unlock()
|
||||
|
||||
var eakIDs []string
|
||||
b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID))
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
// it may happen that no record is found; we'll continue with an empty slice
|
||||
} else {
|
||||
if err := json.Unmarshal(b, &eakIDs); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
newEAKIDs := removeElement(eakIDs, eakID)
|
||||
var (
|
||||
_old interface{} = eakIDs
|
||||
_new interface{} = newEAKIDs
|
||||
)
|
||||
|
||||
// ensure that the DB gets the expected value when the slice is empty; otherwise
|
||||
// it'll return with an error that indicates that the DBs view of the data is
|
||||
// different from the last read (i.e. _old is different from what the DB has).
|
||||
if len(eakIDs) == 0 {
|
||||
_old = nil
|
||||
}
|
||||
|
||||
if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil {
|
||||
return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// referenceKey returns a unique key for a reference per provisioner
|
||||
func referenceKey(provisionerID, reference string) string {
|
||||
return provisionerID + "." + reference
|
||||
}
|
||||
|
||||
// sliceIndex finds the index of item in slice
|
||||
func sliceIndex(slice []string, item string) int {
|
||||
for i := range slice {
|
||||
if slice[i] == item {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// removeElement deletes the item if it exists in the
|
||||
// slice. It returns a new slice, keeping the old one intact.
|
||||
func removeElement(slice []string, item string) []string {
|
||||
|
||||
newSlice := make([]string, 0)
|
||||
index := sliceIndex(slice, item)
|
||||
if index < 0 {
|
||||
newSlice = append(newSlice, slice...)
|
||||
return newSlice
|
||||
}
|
||||
|
||||
newSlice = append(newSlice, slice[:index]...)
|
||||
|
||||
return append(newSlice, slice[index+1:]...)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,246 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
|
||||
type CreateExternalAccountKeyRequest struct {
|
||||
Reference string `json:"reference"`
|
||||
}
|
||||
|
||||
// Validate validates a new ACME EAB Key request body.
|
||||
func (r *CreateExternalAccountKeyRequest) Validate() error {
|
||||
if len(r.Reference) > 256 { // an arbitrary, but sensible (IMO), limit
|
||||
return fmt.Errorf("reference length %d exceeds the maximum (256)", len(r.Reference))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKeysResponse is the type for GET /admin/acme/eab responses
|
||||
type GetExternalAccountKeysResponse struct {
|
||||
EAKs []*linkedca.EABKey `json:"eaks"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
|
||||
// before serving requests that act on ACME EAB credentials.
|
||||
func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
provName := chi.URLParam(r, "provisionerName")
|
||||
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if !eabEnabled {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
|
||||
// provisioner is set to true and thus has EAB enabled.
|
||||
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil {
|
||||
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
|
||||
}
|
||||
|
||||
prov, err := h.db.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
details := prov.GetDetails()
|
||||
if details == nil {
|
||||
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
acmeProvisioner := details.GetACME()
|
||||
if acmeProvisioner == nil {
|
||||
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
return acmeProvisioner.GetRequireEab(), prov, nil
|
||||
}
|
||||
|
||||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (*linkedca.Provisioner, error) {
|
||||
val := ctx.Value(provisionerContextKey)
|
||||
if val == nil {
|
||||
return nil, admin.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
pval, ok := val.(*linkedca.Provisioner)
|
||||
if !ok || pval == nil {
|
||||
return nil, admin.NewErrorISE("provisioner in context is not a linkedca.Provisioner")
|
||||
}
|
||||
return pval, nil
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey creates a new External Account Binding key
|
||||
func (h *Handler) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateExternalAccountKeyRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating request body"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error getting provisioner from context"))
|
||||
return
|
||||
}
|
||||
|
||||
// check if a key with the reference does not exist (only when a reference was in the request)
|
||||
reference := body.Reference
|
||||
if reference != "" {
|
||||
k, err := h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
||||
// retrieving an EAB key from DB results in an error if it doesn't exist, which is what we're looking for,
|
||||
// but other errors can also happen. Return early if that happens; continuing if it was acme.ErrNotFound.
|
||||
if shouldWriteError := err != nil && !errors.Is(err, acme.ErrNotFound); shouldWriteError {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "could not lookup external account key by reference"))
|
||||
return
|
||||
}
|
||||
// if a key was found, return HTTP 409 conflict
|
||||
if k != nil {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "an ACME EAB key for provisioner '%s' with reference '%s' already exists", prov.GetName(), reference)
|
||||
err.Status = 409
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
// continue execution if no key was found for the reference
|
||||
}
|
||||
|
||||
eak, err := h.acmeDB.CreateExternalAccountKey(ctx, prov.GetId(), reference)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("error creating ACME EAB key for provisioner '%s'", prov.GetName())
|
||||
if reference != "" {
|
||||
msg += fmt.Sprintf(" and reference '%s'", reference)
|
||||
}
|
||||
api.WriteError(w, admin.WrapErrorISE(err, msg))
|
||||
return
|
||||
}
|
||||
|
||||
response := &linkedca.EABKey{
|
||||
Id: eak.ID,
|
||||
HmacKey: eak.KeyBytes,
|
||||
Provisioner: prov.GetName(),
|
||||
Reference: eak.Reference,
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, response, http.StatusCreated)
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey deletes an ACME External Account Key.
|
||||
func (h *Handler) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
keyID := chi.URLParam(r, "id")
|
||||
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error getting provisioner from context"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.acmeDB.DeleteExternalAccountKey(ctx, prov.GetId(), keyID); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error deleting ACME EAB Key '%s'", keyID))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys returns ACME EAB Keys. If a reference is specified,
|
||||
// only the ExternalAccountKey with that reference is returned. Otherwise all
|
||||
// ExternalAccountKeys in the system for a specific provisioner are returned.
|
||||
func (h *Handler) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var (
|
||||
key *acme.ExternalAccountKey
|
||||
keys []*acme.ExternalAccountKey
|
||||
err error
|
||||
cursor string
|
||||
nextCursor string
|
||||
limit int
|
||||
)
|
||||
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error getting provisioner from context"))
|
||||
return
|
||||
}
|
||||
|
||||
if cursor, limit, err = api.ParseCursor(r); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor and limit from query params"))
|
||||
return
|
||||
}
|
||||
|
||||
reference := chi.URLParam(r, "reference")
|
||||
if reference != "" {
|
||||
if key, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving external account key with reference '%s'", reference))
|
||||
return
|
||||
}
|
||||
if key != nil {
|
||||
keys = []*acme.ExternalAccountKey{key}
|
||||
}
|
||||
} else {
|
||||
if keys, nextCursor, err = h.acmeDB.GetExternalAccountKeys(ctx, prov.GetId(), cursor, limit); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving external account keys"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
provisionerName := prov.GetName()
|
||||
eaks := make([]*linkedca.EABKey, len(keys))
|
||||
for i, k := range keys {
|
||||
eaks[i] = &linkedca.EABKey{
|
||||
Id: k.ID,
|
||||
HmacKey: []byte{},
|
||||
Provisioner: provisionerName,
|
||||
Reference: k.Reference,
|
||||
Account: k.AccountID,
|
||||
CreatedAt: timestamppb.New(k.CreatedAt),
|
||||
BoundAt: timestamppb.New(k.BoundAt),
|
||||
}
|
||||
}
|
||||
|
||||
api.JSON(w, &GetExternalAccountKeysResponse{
|
||||
EAKs: eaks,
|
||||
NextCursor: nextCursor,
|
||||
})
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,919 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type mockAdminAuthority struct {
|
||||
MockLoadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
MockGetProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two
|
||||
MockErr error
|
||||
MockIsAdminAPIEnabled func() bool
|
||||
MockLoadAdminByID func(id string) (*linkedca.Admin, bool)
|
||||
MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error)
|
||||
MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error
|
||||
MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error)
|
||||
MockRemoveAdmin func(ctx context.Context, id string) error
|
||||
MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error)
|
||||
MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
MockRemoveProvisioner func(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
|
||||
if m.MockIsAdminAPIEnabled != nil {
|
||||
return m.MockIsAdminAPIEnabled()
|
||||
}
|
||||
return m.MockRet1.(bool)
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.MockLoadProvisionerByName != nil {
|
||||
return m.MockLoadProvisionerByName(name)
|
||||
}
|
||||
return m.MockRet1.(provisioner.Interface), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
|
||||
if m.MockGetProvisioners != nil {
|
||||
return m.MockGetProvisioners(nextCursor, limit)
|
||||
}
|
||||
return m.MockRet1.(provisioner.List), m.MockRet2.(string), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||
if m.MockLoadAdminByID != nil {
|
||||
return m.MockLoadAdminByID(id)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool)
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) {
|
||||
if m.MockGetAdmins != nil {
|
||||
return m.MockGetAdmins(cursor, limit)
|
||||
}
|
||||
return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
if m.MockStoreAdmin != nil {
|
||||
return m.MockStoreAdmin(ctx, adm, prov)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
if m.MockUpdateAdmin != nil {
|
||||
return m.MockUpdateAdmin(ctx, id, nu)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) RemoveAdmin(ctx context.Context, id string) error {
|
||||
if m.MockRemoveAdmin != nil {
|
||||
return m.MockRemoveAdmin(ctx, id)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
if m.MockAuthorizeAdminToken != nil {
|
||||
return m.MockAuthorizeAdminToken(r, token)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
if m.MockStoreProvisioner != nil {
|
||||
return m.MockStoreProvisioner(ctx, prov)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
if m.MockLoadProvisionerByID != nil {
|
||||
return m.MockLoadProvisionerByID(id)
|
||||
}
|
||||
return m.MockRet1.(provisioner.Interface), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
if m.MockUpdateProvisioner != nil {
|
||||
return m.MockUpdateProvisioner(ctx, nu)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) error {
|
||||
if m.MockRemoveProvisioner != nil {
|
||||
return m.MockRemoveProvisioner(ctx, id)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func TestCreateAdminRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Subject string
|
||||
Provisioner string
|
||||
Type linkedca.Admin_Type
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
err *admin.Error
|
||||
}{
|
||||
{
|
||||
name: "fail/subject-empty",
|
||||
fields: fields{
|
||||
Subject: "",
|
||||
Provisioner: "",
|
||||
Type: 0,
|
||||
},
|
||||
err: admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty"),
|
||||
},
|
||||
{
|
||||
name: "fail/provisioner-empty",
|
||||
fields: fields{
|
||||
Subject: "admin",
|
||||
Provisioner: "",
|
||||
Type: 0,
|
||||
},
|
||||
err: admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty"),
|
||||
},
|
||||
{
|
||||
name: "fail/invalid-type",
|
||||
fields: fields{
|
||||
Subject: "admin",
|
||||
Provisioner: "prov",
|
||||
Type: -1,
|
||||
},
|
||||
err: admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type"),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
Subject: "admin",
|
||||
Provisioner: "prov",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
car := &CreateAdminRequest{
|
||||
Subject: tt.fields.Subject,
|
||||
Provisioner: tt.fields.Provisioner,
|
||||
Type: tt.fields.Type,
|
||||
}
|
||||
err := car.Validate()
|
||||
|
||||
if (err != nil) != (tt.err != nil) {
|
||||
t.Errorf("CreateAdminRequest.Validate() error = %v, wantErr %v", err, (tt.err != nil))
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
assert.Type(t, &admin.Error{}, err)
|
||||
adminErr, _ := err.(*admin.Error)
|
||||
assert.Equals(t, tt.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tt.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, tt.err.Status, adminErr.Status)
|
||||
assert.Equals(t, tt.err.Message, adminErr.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAdminRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Type linkedca.Admin_Type
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
err *admin.Error
|
||||
}{
|
||||
{
|
||||
name: "fail/invalid-type",
|
||||
fields: fields{
|
||||
Type: -1,
|
||||
},
|
||||
err: admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type"),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
uar := &UpdateAdminRequest{
|
||||
Type: tt.fields.Type,
|
||||
}
|
||||
|
||||
err := uar.Validate()
|
||||
|
||||
if (err != nil) != (tt.err != nil) {
|
||||
t.Errorf("CreateAdminRequest.Validate() error = %v, wantErr %v", err, (tt.err != nil))
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
assert.Type(t, &admin.Error{}, err)
|
||||
adminErr, _ := err.(*admin.Error)
|
||||
assert.Equals(t, tt.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tt.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, tt.err.Status, adminErr.Status)
|
||||
assert.Equals(t, tt.err.Message, adminErr.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_GetAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
statusCode int
|
||||
err *admin.Error
|
||||
adm *linkedca.Admin
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.LoadAdminByID-not-found": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("id", "adminID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) {
|
||||
assert.Equals(t, "adminID", id)
|
||||
return nil, false
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
statusCode: 404,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorNotFoundType.String(),
|
||||
Status: 404,
|
||||
Detail: "resource not found",
|
||||
Message: "admin adminID not found",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("id", "adminID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
adm := &linkedca.Admin{
|
||||
Id: "adminID",
|
||||
AuthorityId: "authorityID",
|
||||
Subject: "admin",
|
||||
ProvisionerId: "provID",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
DeletedAt: timestamppb.New(deletedAt),
|
||||
}
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) {
|
||||
assert.Equals(t, "adminID", id)
|
||||
return adm, true
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
statusCode: 200,
|
||||
err: nil,
|
||||
adm: adm,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
adminErr := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
adm := &linkedca.Admin{}
|
||||
err := readProtoJSON(res.Body, adm)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(tc.adm, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.adm, adm, opts...))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_GetAdmins(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
req *http.Request
|
||||
statusCode int
|
||||
err *admin.Error
|
||||
resp GetAdminsResponse
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/parse-cursor": func(t *testing.T) test {
|
||||
req := httptest.NewRequest("GET", "/foo?limit=A", nil)
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
req: req,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Status: 400,
|
||||
Type: admin.ErrorBadRequestType.String(),
|
||||
Detail: "bad request",
|
||||
Message: "error parsing cursor and limit from query params: limit 'A' is not an integer: strconv.Atoi: parsing \"A\": invalid syntax",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.GetAdmins": func(t *testing.T) test {
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
auth := &mockAdminAuthority{
|
||||
MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) {
|
||||
assert.Equals(t, "", cursor)
|
||||
assert.Equals(t, 0, limit)
|
||||
return nil, "", errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
req: req,
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Status: 500,
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
Detail: "the server experienced an internal error",
|
||||
Message: "error retrieving paginated admins: force",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
adm1 := &linkedca.Admin{
|
||||
Id: "adminID1",
|
||||
AuthorityId: "authorityID1",
|
||||
Subject: "admin1",
|
||||
ProvisionerId: "provID",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
DeletedAt: timestamppb.New(deletedAt),
|
||||
}
|
||||
adm2 := &linkedca.Admin{
|
||||
Id: "adminID2",
|
||||
AuthorityId: "authorityID",
|
||||
Subject: "admin2",
|
||||
ProvisionerId: "provID",
|
||||
Type: linkedca.Admin_ADMIN,
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
DeletedAt: timestamppb.New(deletedAt),
|
||||
}
|
||||
auth := &mockAdminAuthority{
|
||||
MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) {
|
||||
assert.Equals(t, "", cursor)
|
||||
assert.Equals(t, 0, limit)
|
||||
return []*linkedca.Admin{
|
||||
adm1,
|
||||
adm2,
|
||||
}, "nextCursorValue", nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
req: req,
|
||||
auth: auth,
|
||||
statusCode: 200,
|
||||
err: nil,
|
||||
resp: GetAdminsResponse{
|
||||
Admins: []*linkedca.Admin{
|
||||
adm1,
|
||||
adm2,
|
||||
},
|
||||
NextCursor: "nextCursorValue",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmins(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
adminErr := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
response := GetAdminsResponse{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response))
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(tc.resp, response, opts...) {
|
||||
t.Errorf("GetAdmins diff =\n%s", cmp.Diff(tc.resp, response, opts...))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_CreateAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
body []byte
|
||||
statusCode int
|
||||
err *admin.Error
|
||||
adm *linkedca.Admin
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/ReadJSON": func(t *testing.T) test {
|
||||
body := []byte("{!?}")
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorBadRequestType.String(),
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "error reading request body: error decoding json: invalid character '!' looking for beginning of object key string",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/validate": func(t *testing.T) test {
|
||||
req := CreateAdminRequest{
|
||||
Subject: "",
|
||||
Provisioner: "",
|
||||
Type: -1,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorBadRequestType.String(),
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "subject cannot be empty",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||
req := CreateAdminRequest{
|
||||
Subject: "admin",
|
||||
Provisioner: "prov",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "prov", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
Status: 500,
|
||||
Detail: "the server experienced an internal error",
|
||||
Message: "error loading provisioner prov: force",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.StoreAdmin": func(t *testing.T) test {
|
||||
req := CreateAdminRequest{
|
||||
Subject: "admin",
|
||||
Provisioner: "prov",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "prov", name)
|
||||
return &provisioner.ACME{
|
||||
ID: "provID",
|
||||
Name: "prov",
|
||||
}, nil
|
||||
},
|
||||
MockStoreAdmin: func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
assert.Equals(t, "admin", adm.Subject)
|
||||
assert.Equals(t, "provID", prov.GetID())
|
||||
return errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
Status: 500,
|
||||
Detail: "the server experienced an internal error",
|
||||
Message: "error storing admin: force",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
req := CreateAdminRequest{
|
||||
Subject: "admin",
|
||||
Provisioner: "prov",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "prov", name)
|
||||
return &provisioner.ACME{
|
||||
ID: "provID",
|
||||
Name: "prov",
|
||||
}, nil
|
||||
},
|
||||
MockStoreAdmin: func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
assert.Equals(t, "admin", adm.Subject)
|
||||
assert.Equals(t, "provID", prov.GetID())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
auth: auth,
|
||||
statusCode: 201,
|
||||
err: nil,
|
||||
adm: &linkedca.Admin{
|
||||
ProvisionerId: "provID",
|
||||
Subject: "admin",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
adminErr := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
adm := &linkedca.Admin{}
|
||||
err := readProtoJSON(res.Body, adm)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(tc.adm, adm, opts...) {
|
||||
t.Errorf("h.CreateAdmin diff =\n%s", cmp.Diff(tc.adm, adm, opts...))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_DeleteAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
statusCode int
|
||||
err *admin.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.RemoveAdmin": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("id", "adminID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockRemoveAdmin: func(ctx context.Context, id string) error {
|
||||
assert.Equals(t, "adminID", id)
|
||||
return errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
Status: 500,
|
||||
Detail: "the server experienced an internal error",
|
||||
Message: "error deleting admin adminID: force",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("id", "adminID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockRemoveAdmin: func(ctx context.Context, id string) error {
|
||||
assert.Equals(t, "adminID", id)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
statusCode: 200,
|
||||
err: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteAdmin(w, req)
|
||||
res := w.Result()
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
adminErr := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
response := DeleteResponse{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response))
|
||||
assert.Equals(t, "ok", response.Status)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_UpdateAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
body []byte
|
||||
statusCode int
|
||||
err *admin.Error
|
||||
adm *linkedca.Admin
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/ReadJSON": func(t *testing.T) test {
|
||||
body := []byte("{!?}")
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorBadRequestType.String(),
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "error reading request body: error decoding json: invalid character '!' looking for beginning of object key string",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/validate": func(t *testing.T) test {
|
||||
req := UpdateAdminRequest{
|
||||
Type: -1,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorBadRequestType.String(),
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "invalid value for admin type",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.UpdateAdmin": func(t *testing.T) test {
|
||||
req := UpdateAdminRequest{
|
||||
Type: linkedca.Admin_ADMIN,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("id", "adminID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "adminID", id)
|
||||
assert.Equals(t, linkedca.Admin_ADMIN, nu.Type)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
Status: 500,
|
||||
Detail: "the server experienced an internal error",
|
||||
Message: "error updating admin adminID: force",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
req := UpdateAdminRequest{
|
||||
Type: linkedca.Admin_ADMIN,
|
||||
}
|
||||
body, err := json.Marshal(req)
|
||||
assert.FatalError(t, err)
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("id", "adminID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
adm := &linkedca.Admin{
|
||||
Id: "adminID",
|
||||
ProvisionerId: "provID",
|
||||
Subject: "admin",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
}
|
||||
auth := &mockAdminAuthority{
|
||||
MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "adminID", id)
|
||||
assert.Equals(t, linkedca.Admin_ADMIN, nu.Type)
|
||||
return adm, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
auth: auth,
|
||||
statusCode: 200,
|
||||
err: nil,
|
||||
adm: adm,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
adminErr := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
adm := &linkedca.Admin{}
|
||||
err := readProtoJSON(res.Body, adm)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(tc.adm, adm, opts...) {
|
||||
t.Errorf("h.UpdateAdmin diff =\n%s", cmp.Diff(tc.adm, adm, opts...))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,225 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
next nextHTTP
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.IsAdminAPIEnabled": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
auth: &mockAdminAuthority{
|
||||
MockIsAdminAPIEnabled: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorNotImplementedType.String(),
|
||||
Status: 501,
|
||||
Detail: "not implemented",
|
||||
Message: "administration API not enabled",
|
||||
},
|
||||
statusCode: 501,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockIsAdminAPIEnabled: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
auth: auth,
|
||||
next: next,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireAPIEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
// nothing to test when the requireAPIEnabled middleware succeeds, currently
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
req *http.Request
|
||||
next nextHTTP
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/missing-authorization-token": func(t *testing.T) test {
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Header["Authorization"] = []string{""}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
req: req,
|
||||
statusCode: 401,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorUnauthorizedType.String(),
|
||||
Status: 401,
|
||||
Detail: "unauthorized",
|
||||
Message: "missing authorization header token",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.AuthorizeAdminToken": func(t *testing.T) test {
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Header["Authorization"] = []string{"token"}
|
||||
auth := &mockAdminAuthority{
|
||||
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "token", token)
|
||||
return nil, admin.NewError(
|
||||
admin.ErrorUnauthorizedType,
|
||||
"not authorized",
|
||||
)
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
auth: auth,
|
||||
req: req,
|
||||
statusCode: 401,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorUnauthorizedType.String(),
|
||||
Status: 401,
|
||||
Detail: "unauthorized",
|
||||
Message: "not authorized",
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Header["Authorization"] = []string{"token"}
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
admin := &linkedca.Admin{
|
||||
Id: "adminID",
|
||||
AuthorityId: "authorityID",
|
||||
Subject: "admin",
|
||||
ProvisionerId: "provID",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
DeletedAt: timestamppb.New(deletedAt),
|
||||
}
|
||||
auth := &mockAdminAuthority{
|
||||
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "token", token)
|
||||
return admin, nil
|
||||
},
|
||||
}
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
|
||||
adm, ok := a.(*linkedca.Admin)
|
||||
if !ok {
|
||||
t.Errorf("expected *linkedca.Admin; got %T", a)
|
||||
return
|
||||
}
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(admin, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
|
||||
}
|
||||
w.Write(nil) // mock response with status 200
|
||||
}
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
auth: auth,
|
||||
req: req,
|
||||
next: next,
|
||||
statusCode: 200,
|
||||
err: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue