package api import ( "bytes" "context" "encoding/json" "io" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "github.com/go-chi/chi" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "go.step.sm/linkedca" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/admin" ) func readProtoJSON(r io.ReadCloser, m proto.Message) error { defer r.Close() data, err := io.ReadAll(r) if err != nil { return err } return protojson.Unmarshal(data, m) } func mockMustAuthority(t *testing.T, a adminAuthority) { t.Helper() fn := mustAuthority t.Cleanup(func() { mustAuthority = fn }) mustAuthority = func(ctx context.Context) adminAuthority { return a } } func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/prov.GetDetails": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") err.Message = "error getting ACME details for provisioner 'provName'" return test{ ctx: ctx, err: err, statusCode: 500, } }, "fail/prov.GetDetails.GetACME": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", Details: &linkedca.ProvisionerDetails{}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") err.Message = "error getting ACME details for provisioner 'provName'" return test{ ctx: ctx, err: err, statusCode: 500, } }, "ok/eab-disabled": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{ RequireEab: false, }, }, }, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") err.Message = "ACME EAB not enabled for provisioner 'provName'" return test{ ctx: ctx, err: err, statusCode: 400, } }, "ok/eab-enabled": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{ RequireEab: true, }, }, }, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx) w := httptest.NewRecorder() requireEABEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } }) } } func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { type fields struct { Reference string } tests := []struct { name string fields fields wantErr bool }{ { name: "fail/reference-too-long", fields: fields{ Reference: strings.Repeat("A", 257), }, wantErr: true, }, { name: "ok/empty-reference", fields: fields{ Reference: "", }, wantErr: false, }, { name: "ok", fields: fields{ Reference: "my-eab-reference", }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := &CreateExternalAccountKeyRequest{ Reference: tt.fields.Reference, } if err := r.Validate(); (err != nil) != tt.wantErr { t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestHandler_CreateExternalAccountKey(t *testing.T) { type test struct { ctx context.Context statusCode int err *admin.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) return test{ ctx: ctx, statusCode: 501, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: http.StatusNotImplemented, Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm", Detail: "not implemented", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("POST", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() acmeResponder := NewACMEAdminResponder() acmeResponder.CreateExternalAccountKey(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestHandler_DeleteExternalAccountKey(t *testing.T) { type test struct { ctx context.Context statusCode int err *admin.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") chiCtx.URLParams.Add("id", "keyID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) return test{ ctx: ctx, statusCode: 501, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: http.StatusNotImplemented, Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm", Detail: "not implemented", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() acmeResponder := NewACMEAdminResponder() acmeResponder.DeleteExternalAccountKey(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestHandler_GetExternalAccountKeys(t *testing.T) { type test struct { ctx context.Context statusCode int req *http.Request err *admin.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") req := httptest.NewRequest("GET", "/foo", nil) ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) return test{ ctx: ctx, statusCode: 501, req: req, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: http.StatusNotImplemented, Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm", Detail: "not implemented", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() acmeResponder := NewACMEAdminResponder() acmeResponder.GetExternalAccountKeys(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func Test_eakToLinked(t *testing.T) { tests := []struct { name string k *acme.ExternalAccountKey want *linkedca.EABKey }{ { name: "no-key", k: nil, want: nil, }, { name: "no-policy", k: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: nil, }, want: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: nil, }, }, { name: "with-policy", k: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.local"}, IPRanges: []string{"10.0.0.0/24"}, }, Denied: acme.PolicyNames{ DNSNames: []string{"badhost.local"}, IPRanges: []string{"10.0.0.30"}, }, }, }, }, want: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"10.0.0.0/24"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"10.0.0.30"}, }, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) { t.Errorf("eakToLinked() = %v, want %v", got, tt.want) } }) } } func Test_linkedEAKToCertificates(t *testing.T) { tests := []struct { name string k *linkedca.EABKey want *acme.ExternalAccountKey }{ { name: "no-key", k: nil, want: nil, }, { name: "no-policy", k: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: nil, }, want: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: nil, }, }, { name: "no-x509-policy", k: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: &linkedca.Policy{}, }, want: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: &acme.Policy{}, }, }, { name: "with-x509-policy", k: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"10.0.0.0/24"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"10.0.0.30"}, }, AllowWildcardNames: true, }, }, }, want: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.local"}, IPRanges: []string{"10.0.0.0/24"}, }, Denied: acme.PolicyNames{ DNSNames: []string{"badhost.local"}, IPRanges: []string{"10.0.0.30"}, }, AllowWildcardNames: true, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) { t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want) } }) } }