|
|
|
@ -8,14 +8,17 @@ import (
|
|
|
|
|
"encoding/pem"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
"net/url"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/go-chi/chi"
|
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
|
"github.com/smallstep/assert"
|
|
|
|
|
"github.com/smallstep/certificates/acme"
|
|
|
|
|
"go.step.sm/crypto/jose"
|
|
|
|
|
"go.step.sm/crypto/pemutil"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -438,16 +441,21 @@ func ch() acme.Challenge {
|
|
|
|
|
func TestHandler_GetChallenge(t *testing.T) {
|
|
|
|
|
chiCtx := chi.NewRouteContext()
|
|
|
|
|
chiCtx.URLParams.Add("chID", "chID")
|
|
|
|
|
chiCtx.URLParams.Add("authzID", "authzID")
|
|
|
|
|
prov := newProv()
|
|
|
|
|
provName := url.PathEscape(prov.GetName())
|
|
|
|
|
|
|
|
|
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
|
|
|
|
url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID")
|
|
|
|
|
|
|
|
|
|
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
|
|
|
|
baseURL.String(), provName, "authzID", "chID")
|
|
|
|
|
|
|
|
|
|
type test struct {
|
|
|
|
|
db acme.DB
|
|
|
|
|
vco *acme.ValidateChallengeOptions
|
|
|
|
|
ctx context.Context
|
|
|
|
|
statusCode int
|
|
|
|
|
ch acme.Challenge
|
|
|
|
|
ch *acme.Challenge
|
|
|
|
|
err *acme.Error
|
|
|
|
|
}
|
|
|
|
|
var tests = map[string]func(t *testing.T) test{
|
|
|
|
@ -485,8 +493,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|
|
|
|
err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
/*
|
|
|
|
|
"fail/validate-challenge-error": func(t *testing.T) test {
|
|
|
|
|
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
|
|
|
|
acc := &acme.Account{ID: "accID"}
|
|
|
|
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
|
|
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
|
|
|
@ -494,75 +501,169 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
|
return test{
|
|
|
|
|
db: &acme.MockDB{
|
|
|
|
|
MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"),
|
|
|
|
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
|
|
|
|
assert.Equals(t, chID, "chID")
|
|
|
|
|
assert.Equals(t, azID, "authzID")
|
|
|
|
|
return nil, acme.NewErrorISE("force")
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
statusCode: 401,
|
|
|
|
|
err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"),
|
|
|
|
|
statusCode: 500,
|
|
|
|
|
err: acme.NewErrorISE("force"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"fail/get-challenge-error": func(t *testing.T) test {
|
|
|
|
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
|
|
|
|
acc := &acme.Account{ID: "accID"}
|
|
|
|
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
|
|
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
|
|
|
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
|
|
|
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
|
return test{
|
|
|
|
|
db: &acme.MockDB{
|
|
|
|
|
MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"),
|
|
|
|
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
|
|
|
|
assert.Equals(t, chID, "chID")
|
|
|
|
|
assert.Equals(t, azID, "authzID")
|
|
|
|
|
return &acme.Challenge{AccountID: "foo"}, nil
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
statusCode: 401,
|
|
|
|
|
err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"),
|
|
|
|
|
err: acme.NewError(acme.ErrorUnauthorizedType, "accout id mismatch"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok/validate-challenge": func(t *testing.T) test {
|
|
|
|
|
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
acc := &acme.Account{ID: "accID", Key: key}
|
|
|
|
|
"fail/no-jwk": func(t *testing.T) test {
|
|
|
|
|
acc := &acme.Account{ID: "accID"}
|
|
|
|
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
|
|
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
|
|
|
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
|
return test{
|
|
|
|
|
db: &acme.MockDB{
|
|
|
|
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
|
|
|
|
assert.Equals(t, chID, "chID")
|
|
|
|
|
assert.Equals(t, azID, "authzID")
|
|
|
|
|
return &acme.Challenge{AccountID: "accID"}, nil
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
statusCode: 500,
|
|
|
|
|
err: acme.NewErrorISE("missing jwk"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"fail/nil-jwk": func(t *testing.T) test {
|
|
|
|
|
acc := &acme.Account{ID: "accID"}
|
|
|
|
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
|
|
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
|
|
|
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
|
|
|
|
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
|
return test{
|
|
|
|
|
db: &acme.MockDB{
|
|
|
|
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
|
|
|
|
assert.Equals(t, chID, "chID")
|
|
|
|
|
assert.Equals(t, azID, "authzID")
|
|
|
|
|
return &acme.Challenge{AccountID: "accID"}, nil
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
statusCode: 500,
|
|
|
|
|
err: acme.NewErrorISE("nil jwk"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"fail/validate-challenge-error": func(t *testing.T) test {
|
|
|
|
|
acc := &acme.Account{ID: "accID"}
|
|
|
|
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
|
|
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
|
|
|
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
|
|
|
|
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
_pub := _jwk.Public()
|
|
|
|
|
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
|
return test{
|
|
|
|
|
db: &acme.MockDB{
|
|
|
|
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
|
|
|
|
assert.Equals(t, chID, "chID")
|
|
|
|
|
assert.Equals(t, azID, "authzID")
|
|
|
|
|
return &acme.Challenge{
|
|
|
|
|
Status: acme.StatusPending,
|
|
|
|
|
Type: "http-01",
|
|
|
|
|
AccountID: "accID",
|
|
|
|
|
}, nil
|
|
|
|
|
},
|
|
|
|
|
MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
|
|
|
|
assert.Equals(t, ch.Status, acme.StatusPending)
|
|
|
|
|
assert.Equals(t, ch.Type, "http-01")
|
|
|
|
|
assert.Equals(t, ch.AccountID, "accID")
|
|
|
|
|
assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
|
|
|
|
|
return acme.NewErrorISE("force")
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
vco: &acme.ValidateChallengeOptions{
|
|
|
|
|
HTTPGet: func(string) (*http.Response, error) {
|
|
|
|
|
return nil, errors.New("force")
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
statusCode: 500,
|
|
|
|
|
err: acme.NewErrorISE("force"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok": func(t *testing.T) test {
|
|
|
|
|
acc := &acme.Account{ID: "accID"}
|
|
|
|
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
|
|
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
|
|
|
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
|
|
|
|
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
_pub := _jwk.Public()
|
|
|
|
|
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
|
|
|
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
|
|
|
ch := ch()
|
|
|
|
|
ch.Status = "valid"
|
|
|
|
|
ch.Validated = time.Now().UTC().Format(time.RFC3339)
|
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
|
return test{
|
|
|
|
|
db: &acme.MockDB{
|
|
|
|
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
|
|
|
|
assert.Equals(t, chID, ch.ID)
|
|
|
|
|
return &ch, nil
|
|
|
|
|
},
|
|
|
|
|
getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
|
|
|
|
|
var ret string
|
|
|
|
|
switch count {
|
|
|
|
|
case 0:
|
|
|
|
|
assert.Equals(t, typ, acme.AuthzLink)
|
|
|
|
|
assert.True(t, abs)
|
|
|
|
|
assert.Equals(t, in, []string{ch.AuthzID})
|
|
|
|
|
ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID)
|
|
|
|
|
case 1:
|
|
|
|
|
assert.Equals(t, typ, acme.ChallengeLink)
|
|
|
|
|
assert.True(t, abs)
|
|
|
|
|
assert.Equals(t, in, []string{ch.ID})
|
|
|
|
|
ret = url
|
|
|
|
|
}
|
|
|
|
|
count++
|
|
|
|
|
return ret
|
|
|
|
|
assert.Equals(t, chID, "chID")
|
|
|
|
|
assert.Equals(t, azID, "authzID")
|
|
|
|
|
return &acme.Challenge{
|
|
|
|
|
ID: "chID",
|
|
|
|
|
AuthzID: "authzID",
|
|
|
|
|
Status: acme.StatusPending,
|
|
|
|
|
Type: "http-01",
|
|
|
|
|
AccountID: "accID",
|
|
|
|
|
}, nil
|
|
|
|
|
},
|
|
|
|
|
MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
|
|
|
|
assert.Equals(t, ch.Status, acme.StatusPending)
|
|
|
|
|
assert.Equals(t, ch.Type, "http-01")
|
|
|
|
|
assert.Equals(t, ch.AccountID, "accID")
|
|
|
|
|
assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
|
|
|
|
|
return nil
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ch: &acme.Challenge{
|
|
|
|
|
ID: "chID",
|
|
|
|
|
AuthzID: "authzID",
|
|
|
|
|
Status: acme.StatusPending,
|
|
|
|
|
Type: "http-01",
|
|
|
|
|
AccountID: "accID",
|
|
|
|
|
URL: url,
|
|
|
|
|
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
|
|
|
|
},
|
|
|
|
|
vco: &acme.ValidateChallengeOptions{
|
|
|
|
|
HTTPGet: func(string) (*http.Response, error) {
|
|
|
|
|
return nil, errors.New("force")
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
statusCode: 200,
|
|
|
|
|
ch: ch,
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
*/
|
|
|
|
|
}
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
|
|
|
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
|
|
|
|
req := httptest.NewRequest("GET", url, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|