diff --git a/scep/api/api.go b/scep/api/api.go index 96e25104..f6e1b1ce 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -305,14 +305,21 @@ func PKIOperation(ctx context.Context, req request) (Response, error) { return Response{}, err } - // NOTE: at this point we have sufficient information for returning nicely signed CertReps - csr := msg.CSRReqMessage.CSR - prov, err := scep.ProvisionerFromContext(ctx) if err != nil { return Response{}, err } + scepProv, ok := prov.(*provisioner.SCEP) + if !ok { + return Response{}, errors.New("wrong type of provisioner in context") + } + + // NOTE: at this point we have sufficient information for returning nicely signed CertReps + csr := msg.CSRReqMessage.CSR + transactionID := string(msg.TransactionID) + challengePassword := msg.CSRReqMessage.ChallengePassword + // NOTE: we're blocking the RenewalReq if the challenge does not match, because otherwise we don't have any authentication. // The macOS SCEP client performs renewals using PKCSreq. The CertNanny SCEP client will use PKCSreq with challenge too, it seems, // even if using the renewal flow as described in the README.md. MicroMDM SCEP client also only does PKCSreq by default, unless @@ -323,22 +330,22 @@ func PKIOperation(ctx context.Context, req request) (Response, error) { // auth.MatchChallengePassword interface/method. Will need to think about methods // that don't just check the password, but do different things on success and // failure too. - switch selectValidationMethod(prov) { + switch selectValidationMethod(scepProv) { case validationMethodWebhook: - c, err := webhook.New(prov.GetOptions().GetWebhooks()) + c, err := webhook.New(scepProv.GetOptions().GetWebhooks()) if err != nil { return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("failed creating SCEP validation webhook controller")) } - if err := c.Validate(ctx, msg.CSRReqMessage.ChallengePassword); err != nil { + if err := c.Validate(ctx, challengePassword, transactionID); err != nil { if errors.Is(err, provisioner.ErrWebhookDenied) { return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("invalid challenge password provided")) } return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("failed validating challenge password")) } default: - challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) + challengeMatches, err := auth.MatchChallengePassword(ctx, challengePassword) if err != nil { - return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("failed checking password")) } if !challengeMatches { // TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too. @@ -372,6 +379,7 @@ func PKIOperation(ctx context.Context, req request) (Response, error) { type validationMethod string const ( + validationMethodNone validationMethod = "none" validationMethodStatic validationMethod = "static" validationMethodWebhook validationMethod = "webhook" ) @@ -380,15 +388,19 @@ const ( // challenges. If a webhook is configured with kind `SCEPCHALLENGE`, // the webhook will be used. Otherwise it will default to the // static challenge value. -func selectValidationMethod(p scep.Provisioner) validationMethod { +func selectValidationMethod(p *provisioner.SCEP) validationMethod { for _, wh := range p.GetOptions().GetWebhooks() { - // if there's at least one webhook for validating SCEP challenges, the - // webhook will be used to perform challenge validation. + // if at least one webhook for validating SCEP challenges has + // been configured, that will be used to perform challenge + // validation. if wh.Kind == linkedca.Webhook_SCEPCHALLENGE.String() { return validationMethodWebhook } } - return validationMethodStatic + if challenge := p.GetChallengePassword(); challenge != "" { + return validationMethodStatic + } + return validationMethodNone } func formatCapabilities(caps []string) []byte { diff --git a/scep/api/api_test.go b/scep/api/api_test.go index bdb51594..ee53d25e 100644 --- a/scep/api/api_test.go +++ b/scep/api/api_test.go @@ -9,6 +9,12 @@ import ( "reflect" "testing" "testing/iotest" + + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/linkedca" ) func Test_decodeRequest(t *testing.T) { @@ -111,3 +117,47 @@ func Test_decodeRequest(t *testing.T) { }) } } + +func Test_selectValidationMethod(t *testing.T) { + tests := []struct { + name string + p *provisioner.SCEP + want validationMethod + }{ + {"webhooks", &provisioner.SCEP{ + Name: "SCEP", + Type: "SCEP", + Options: &provisioner.Options{ + Webhooks: []*provisioner.Webhook{ + { + Kind: linkedca.Webhook_SCEPCHALLENGE.String(), + }, + }, + }, + Claims: &provisioner.Claims{}, + }, "webhook"}, + {"challenge", &provisioner.SCEP{ + Name: "SCEP", + Type: "SCEP", + ChallengePassword: "pass", + Options: &provisioner.Options{}, + Claims: &provisioner.Claims{}, + }, "static"}, + {"none", &provisioner.SCEP{ + Name: "SCEP", + Type: "SCEP", + Options: &provisioner.Options{}, + Claims: &provisioner.Claims{}, + }, "none"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.p.Init(provisioner.Config{ + Claims: config.GlobalProvisionerClaims, + }) + require.NoError(t, err) + got := selectValidationMethod(tt.p) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/scep/api/webhook/webhook.go b/scep/api/webhook/webhook.go index b191c426..dbaa5749 100644 --- a/scep/api/webhook/webhook.go +++ b/scep/api/webhook/webhook.go @@ -31,7 +31,7 @@ func New(webhooks []*provisioner.Webhook) (*Controller, error) { // webhooks will not be executed. If none of the webhooks // indicates the challenge is accepted, an error is // returned. -func (c *Controller) Validate(ctx context.Context, challenge string) error { +func (c *Controller) Validate(ctx context.Context, challenge, transactionID string) error { for _, wh := range c.webhooks { if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() { continue @@ -40,7 +40,8 @@ func (c *Controller) Validate(ctx context.Context, challenge string) error { continue } req := &webhook.RequestBody{ - SCEPChallenge: challenge, + SCEPChallenge: challenge, + SCEPTransactionID: transactionID, } resp, err := wh.DoWithContext(ctx, c.client, req, nil) // TODO(hs): support templated URL? Requires some refactoring if err != nil { diff --git a/scep/api/webhook/webhook_test.go b/scep/api/webhook/webhook_test.go new file mode 100644 index 00000000..5d8012ac --- /dev/null +++ b/scep/api/webhook/webhook_test.go @@ -0,0 +1,176 @@ +package webhook + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/provisioner" +) + +func TestController_Validate(t *testing.T) { + type request struct { + Challenge string `json:"scepChallenge"` + TransactionID string `json:"scepTransactionID"` + } + type response struct { + Allow bool `json:"allow"` + } + nokServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req := &request{} + err := json.NewDecoder(r.Body).Decode(req) + require.NoError(t, err) + assert.Equal(t, "not-allowed", req.Challenge) + assert.Equal(t, "transaction-1", req.TransactionID) + b, err := json.Marshal(response{Allow: false}) + require.NoError(t, err) + w.WriteHeader(200) + w.Write(b) + })) + okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req := &request{} + err := json.NewDecoder(r.Body).Decode(req) + require.NoError(t, err) + assert.Equal(t, "challenge", req.Challenge) + assert.Equal(t, "transaction-1", req.TransactionID) + b, err := json.Marshal(response{Allow: true}) + require.NoError(t, err) + w.WriteHeader(200) + w.Write(b) + })) + type fields struct { + client *http.Client + webhooks []*provisioner.Webhook + } + type args struct { + challenge string + transactionID string + } + tests := []struct { + name string + fields fields + args args + server *httptest.Server + expErr error + }{ + { + name: "fail/no-webhook", + fields: fields{http.DefaultClient, nil}, + args: args{"no-webhook", "transaction-1"}, + expErr: errors.New("webhook server did not allow request"), + }, + { + name: "fail/no-scep-webhook", + fields: fields{http.DefaultClient, []*provisioner.Webhook{ + { + Kind: linkedca.Webhook_AUTHORIZING.String(), + }, + }}, + args: args{"no-scep-webhook", "transaction-1"}, + expErr: errors.New("webhook server did not allow request"), + }, + { + name: "fail/wrong-cert-type", + fields: fields{http.DefaultClient, []*provisioner.Webhook{ + { + Kind: linkedca.Webhook_SCEPCHALLENGE.String(), + CertType: linkedca.Webhook_SSH.String(), + }, + }}, + args: args{"wrong-cert-type", "transaction-1"}, + expErr: errors.New("webhook server did not allow request"), + }, + { + name: "fail/wrong-secret-value", + fields: fields{http.DefaultClient, []*provisioner.Webhook{ + { + ID: "webhook-id-1", + Name: "webhook-name-1", + Secret: "{{}}", + Kind: linkedca.Webhook_SCEPCHALLENGE.String(), + CertType: linkedca.Webhook_X509.String(), + URL: okServer.URL, + }, + }}, + args: args{ + challenge: "wrong-secret-value", + transactionID: "transaction-1", + }, + expErr: errors.New("failed executing webhook request: illegal base64 data at input byte 0"), + }, + { + name: "fail/not-allowed", + fields: fields{http.DefaultClient, []*provisioner.Webhook{ + { + ID: "webhook-id-1", + Name: "webhook-name-1", + Secret: "MTIzNAo=", + Kind: linkedca.Webhook_SCEPCHALLENGE.String(), + CertType: linkedca.Webhook_X509.String(), + URL: nokServer.URL, + }, + }}, + args: args{ + challenge: "not-allowed", + transactionID: "transaction-1", + }, + server: nokServer, + expErr: errors.New("webhook server did not allow request"), + }, + { + name: "ok", + fields: fields{http.DefaultClient, []*provisioner.Webhook{ + { + ID: "webhook-id-1", + Name: "webhook-name-1", + Secret: "MTIzNAo=", + Kind: linkedca.Webhook_SCEPCHALLENGE.String(), + CertType: linkedca.Webhook_X509.String(), + URL: okServer.URL, + }, + }}, + args: args{ + challenge: "challenge", + transactionID: "transaction-1", + }, + server: okServer, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + client: tt.fields.client, + webhooks: tt.fields.webhooks, + } + + if tt.server != nil { + defer tt.server.Close() + } + + ctx := context.Background() + err := c.Validate(ctx, tt.args.challenge, tt.args.transactionID) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + }) + } +} + +func TestController_isCertTypeOK(t *testing.T) { + c := &Controller{} + assert.True(t, c.isCertTypeOK(&provisioner.Webhook{CertType: linkedca.Webhook_X509.String()})) + assert.True(t, c.isCertTypeOK(&provisioner.Webhook{CertType: linkedca.Webhook_ALL.String()})) + assert.True(t, c.isCertTypeOK(&provisioner.Webhook{CertType: ""})) + assert.False(t, c.isCertTypeOK(&provisioner.Webhook{CertType: linkedca.Webhook_SSH.String()})) +} diff --git a/webhook/types.go b/webhook/types.go index a1e10efe..9605742a 100644 --- a/webhook/types.go +++ b/webhook/types.go @@ -68,6 +68,7 @@ type RequestBody struct { X509Certificate *X509Certificate `json:"x509Certificate,omitempty"` SSHCertificateRequest *SSHCertificateRequest `json:"sshCertificateRequest,omitempty"` SSHCertificate *SSHCertificate `json:"sshCertificate,omitempty"` - // Only set for SCEP requests - SCEPChallenge string `json:"scepChallenge,omitempty"` + // Only set for SCEP challenge validation requests + SCEPChallenge string `json:"scepChallenge,omitempty"` + SCEPTransactionID string `json:"scepTransactionID,omitempty"` }