package api

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/go-chi/chi"
	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/authority/admin"
	"github.com/smallstep/certificates/authority/provisioner"
	"go.step.sm/linkedca"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
)

// readProtoJSON reads r until EOF, closes it, and unmarshals the bytes that
// were read into m using protojson. It returns the read error or the
// unmarshal error, whichever occurs first.
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
	defer r.Close()
	data, err := io.ReadAll(r)
	if err != nil {
		return err
	}
	return protojson.Unmarshal(data, m)
}

// TestHandler_requireEABEnabled drives the requireEABEnabled middleware with
// mocked authority and admin DB implementations. The cases expect a 500 when
// the provisioner cannot be loaded, a 400 when the provisioner does not
// require EAB, and a 200 (written by the next handler) when it does.
func TestHandler_requireEABEnabled(t *testing.T) {
	type test struct {
		ctx        context.Context
		adminDB    admin.DB
		auth       adminAuthority
		next       nextHTTP
		err        *admin.Error
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
			chiCtx := chi.NewRouteContext()
			chiCtx.URLParams.Add("provisionerName", "provName")
			ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return nil, errors.New("force")
				},
			}
			err := admin.NewErrorISE("error loading provisioner provName: force")
			err.Message = "error loading provisioner provName: force"
			return test{
				ctx:        ctx,
				auth:       auth,
				err:        err,
				statusCode: 500,
			}
		},
		"ok/eab-disabled": func(t *testing.T) test {
			chiCtx := chi.NewRouteContext()
			chiCtx.URLParams.Add("provisionerName", "provName")
			ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return &linkedca.Provisioner{
						Id:   "provID",
						Name: "provName",
						Details: &linkedca.ProvisionerDetails{
							Data: &linkedca.ProvisionerDetails_ACME{
								ACME: &linkedca.ACMEProvisioner{
									RequireEab: false,
								},
							},
						},
					}, nil
				},
			}
			err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
			err.Message = "ACME EAB not enabled for provisioner provName"
			return test{
				ctx:        ctx,
				auth:       auth,
				adminDB:    db,
				err:        err,
				statusCode: 400,
			}
		},
		"ok/eab-enabled": func(t *testing.T) test {
			chiCtx := chi.NewRouteContext()
			chiCtx.URLParams.Add("provisionerName", "provName")
			ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return &linkedca.Provisioner{
						Id:   "provID",
						Name: "provName",
						Details: &linkedca.ProvisionerDetails{
							Data: &linkedca.ProvisionerDetails_ACME{
								ACME: &linkedca.ACMEProvisioner{
									RequireEab: true,
								},
							},
						},
					}, nil
				},
			}
			return test{
				ctx:     ctx,
				auth:    auth,
				adminDB: db,
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(nil) // mock response with status 200
				},
				statusCode: 200,
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			h := &Handler{
				auth:    tc.auth,
				adminDB: tc.adminDB,
				acmeDB:  nil,
			}

			req := httptest.NewRequest(http.MethodGet, "/foo", nil) // chi routing is prepared in test setup
			req = req.WithContext(tc.ctx)
			w := httptest.NewRecorder()
			h.requireEABEnabled(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, tc.statusCode, res.StatusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 {
				// The status assertion above does not stop the test, so a case
				// that expected success can still reach this branch with a nil
				// tc.err. Fail cleanly instead of dereferencing nil.
				if tc.err == nil {
					t.Fatalf("Handler.requireEABEnabled() unexpected error response %d: %s", res.StatusCode, body)
				}

				err := admin.Error{}
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))

				assert.Equals(t, tc.err.Type, err.Type)
				assert.Equals(t, tc.err.Message, err.Message)
				assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
				assert.Equals(t, tc.err.Detail, err.Detail)
				assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
			}
		})
	}
}

// TestHandler_provisionerHasEABEnabled calls provisionerHasEABEnabled directly
// with mocked authority and admin DB implementations. Failure cases cover an
// error from the authority, an error from the DB, nil provisioner details and
// a nil ACME provisioner; success cases cover RequireEab false and true.
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
	type test struct {
		adminDB         admin.DB
		auth            adminAuthority
		provisionerName string
		want            bool
		err             *admin.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return nil, errors.New("force")
				},
			}
			return test{
				auth:            auth,
				provisionerName: "provName",
				want:            false,
				err:             admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
			}
		},
		"fail/db.GetProvisioner": func(t *testing.T) test {
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return nil, errors.New("force")
				},
			}
			return test{
				auth:            auth,
				adminDB:         db,
				provisionerName: "provName",
				want:            false,
				err:             admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
			}
		},
		"fail/prov.GetDetails": func(t *testing.T) test {
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return &linkedca.Provisioner{
						Id:      "provID",
						Name:    "provName",
						Details: nil,
					}, nil
				},
			}
			return test{
				auth:            auth,
				adminDB:         db,
				provisionerName: "provName",
				want:            false,
				err:             admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
			}
		},
		"fail/details.GetACME": func(t *testing.T) test {
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "provName", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return &linkedca.Provisioner{
						Id:   "provID",
						Name: "provName",
						Details: &linkedca.ProvisionerDetails{
							Data: &linkedca.ProvisionerDetails_ACME{
								ACME: nil,
							},
						},
					}, nil
				},
			}
			return test{
				auth:            auth,
				adminDB:         db,
				provisionerName: "provName",
				want:            false,
				err:             admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
			}
		},
		"ok/eab-disabled": func(t *testing.T) test {
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "eab-disabled", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return &linkedca.Provisioner{
						Id:   "provID",
						Name: "eab-disabled",
						Details: &linkedca.ProvisionerDetails{
							Data: &linkedca.ProvisionerDetails_ACME{
								ACME: &linkedca.ACMEProvisioner{
									RequireEab: false,
								},
							},
						},
					}, nil
				},
			}
			return test{
				adminDB:         db,
				auth:            auth,
				provisionerName: "eab-disabled",
				want:            false,
			}
		},
		"ok/eab-enabled": func(t *testing.T) test {
			auth := &mockAdminAuthority{
				MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
					assert.Equals(t, "eab-enabled", name)
					return &provisioner.MockProvisioner{
						MgetID: func() string {
							return "provID"
						},
					}, nil
				},
			}
			db := &admin.MockDB{
				MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
					assert.Equals(t, "provID", id)
					return &linkedca.Provisioner{
						Id:   "provID",
						Name: "eab-enabled",
						Details: &linkedca.ProvisionerDetails{
							Data: &linkedca.ProvisionerDetails_ACME{
								ACME: &linkedca.ACMEProvisioner{
									RequireEab: true,
								},
							},
						},
					}, nil
				},
			}
			return test{
				adminDB:         db,
				auth:            auth,
				provisionerName: "eab-enabled",
				want:            true,
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			h := &Handler{
				auth:    tc.auth,
				adminDB: tc.adminDB,
				acmeDB:  nil,
			}
			got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
			if (err != nil) != (tc.err != nil) {
				t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
				return
			}
			if tc.err != nil {
				assert.Type(t, &linkedca.Provisioner{}, prov)

				// Check the assertion result: dereferencing a nil *admin.Error
				// would panic and hide the real failure.
				adminError, ok := err.(*admin.Error)
				if !ok {
					t.Fatalf("Handler.provisionerHasEABEnabled() error type = %T, want *admin.Error", err)
				}
				assert.Equals(t, tc.err.Type, adminError.Type)
				assert.Equals(t, tc.err.Status, adminError.Status)
				assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
				assert.Equals(t, tc.err.Message, adminError.Message)
				assert.Equals(t, tc.err.Detail, adminError.Detail)
				return
			}
			if got != tc.want {
				t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
			}
		})
	}
}

// TestCreateExternalAccountKeyRequest_Validate checks Validate on
// CreateExternalAccountKeyRequest: a 257-character reference is expected to be
// rejected, while an empty reference and a short reference are expected to
// pass.
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
	type fields struct {
		Reference string
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name: "fail/reference-too-long",
			fields: fields{
				Reference: strings.Repeat("A", 257),
			},
			wantErr: true,
		},
		{
			name: "ok/empty-reference",
			fields: fields{
				Reference: "",
			},
			wantErr: false,
		},
		{
			name: "ok",
			fields: fields{
				Reference: "my-eab-reference",
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := &CreateExternalAccountKeyRequest{
				Reference: tt.fields.Reference,
			}
			if err := r.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestHandler_CreateExternalAccountKey expects CreateExternalAccountKey on the
// responder returned by NewACMEAdminResponder to answer with a 501 and a JSON
// admin error of the "not implemented" type.
func TestHandler_CreateExternalAccountKey(t *testing.T) {
	type test struct {
		ctx        context.Context
		statusCode int
		err        *admin.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			chiCtx := chi.NewRouteContext()
			ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
			return test{
				ctx:        ctx,
				statusCode: 501,
				err: &admin.Error{
					Type:    admin.ErrorNotImplementedType.String(),
					Status:  http.StatusNotImplemented,
					Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm",
					Detail:  "not implemented",
				},
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodPost, "/foo", nil) // chi routing is prepared in test setup
			req = req.WithContext(tc.ctx)
			w := httptest.NewRecorder()
			acmeResponder := NewACMEAdminResponder()
			acmeResponder.CreateExternalAccountKey(w, req)
			res := w.Result()
			assert.Equals(t, tc.statusCode, res.StatusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			adminErr := admin.Error{}
			assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))

			assert.Equals(t, tc.err.Type, adminErr.Type)
			assert.Equals(t, tc.err.Message, adminErr.Message)
			assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
			assert.Equals(t, tc.err.Detail, adminErr.Detail)
			assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
		})
	}
}

// TestHandler_DeleteExternalAccountKey expects DeleteExternalAccountKey on the
// responder returned by NewACMEAdminResponder to answer with a 501 and a JSON
// admin error of the "not implemented" type.
func TestHandler_DeleteExternalAccountKey(t *testing.T) {
	type test struct {
		ctx        context.Context
		statusCode int
		err        *admin.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			chiCtx := chi.NewRouteContext()
			chiCtx.URLParams.Add("provisionerName", "provName")
			chiCtx.URLParams.Add("id", "keyID")
			ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
			return test{
				ctx:        ctx,
				statusCode: 501,
				err: &admin.Error{
					Type:    admin.ErrorNotImplementedType.String(),
					Status:  http.StatusNotImplemented,
					Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm",
					Detail:  "not implemented",
				},
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodDelete, "/foo", nil) // chi routing is prepared in test setup
			req = req.WithContext(tc.ctx)
			w := httptest.NewRecorder()
			acmeResponder := NewACMEAdminResponder()
			acmeResponder.DeleteExternalAccountKey(w, req)
			res := w.Result()
			assert.Equals(t, tc.statusCode, res.StatusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			adminErr := admin.Error{}
			assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))

			assert.Equals(t, tc.err.Type, adminErr.Type)
			assert.Equals(t, tc.err.Message, adminErr.Message)
			assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
			assert.Equals(t, tc.err.Detail, adminErr.Detail)
			assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
		})
	}
}

// TestHandler_GetExternalAccountKeys expects GetExternalAccountKeys on the
// responder returned by NewACMEAdminResponder to answer with a 501 and a JSON
// admin error of the "not implemented" type.
func TestHandler_GetExternalAccountKeys(t *testing.T) {
	type test struct {
		ctx        context.Context
		statusCode int
		req        *http.Request
		err        *admin.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			chiCtx := chi.NewRouteContext()
			chiCtx.URLParams.Add("provisionerName", "provName")
			req := httptest.NewRequest(http.MethodGet, "/foo", nil)
			ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
			return test{
				ctx:        ctx,
				statusCode: 501,
				req:        req,
				err: &admin.Error{
					Type:    admin.ErrorNotImplementedType.String(),
					Status:  http.StatusNotImplemented,
					Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm",
					Detail:  "not implemented",
				},
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			req := tc.req.WithContext(tc.ctx)
			w := httptest.NewRecorder()
			acmeResponder := NewACMEAdminResponder()
			acmeResponder.GetExternalAccountKeys(w, req)

			res := w.Result()
			assert.Equals(t, tc.statusCode, res.StatusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			adminErr := admin.Error{}
			assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))

			assert.Equals(t, tc.err.Type, adminErr.Type)
			assert.Equals(t, tc.err.Message, adminErr.Message)
			assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
			assert.Equals(t, tc.err.Detail, adminErr.Detail)
			assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
		})
	}
}