diff --git a/authority/admin/api/policy_test.go b/authority/admin/api/policy_test.go index b5987104..77879190 100644 --- a/authority/admin/api/policy_test.go +++ b/authority/admin/api/policy_test.go @@ -11,11 +11,11 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" "google.golang.org/protobuf/encoding/protojson" "go.step.sm/linkedca" - "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" @@ -29,13 +29,67 @@ func (f *fakeLinkedCA) IsLinkedCA() bool { return true } +// testAdminError is an error type that models the expected +// error body returned. +type testAdminError struct { + Type string `json:"type"` + Message string `json:"message"` + Detail string `json:"detail"` +} + +type testX509Policy struct { + Allow *testX509Names `json:"allow,omitempty"` + Deny *testX509Names `json:"deny,omitempty"` + AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"` +} + +type testX509Names struct { + CommonNames []string `json:"commonNames,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ips,omitempty"` + EmailAddresses []string `json:"emails,omitempty"` + URIDomains []string `json:"uris,omitempty"` +} + +type testSSHPolicy struct { + User *testSSHUserPolicy `json:"user,omitempty"` + Host *testSSHHostPolicy `json:"host,omitempty"` +} + +type testSSHHostPolicy struct { + Allow *testSSHHostNames `json:"allow,omitempty"` + Deny *testSSHHostNames `json:"deny,omitempty"` +} + +type testSSHHostNames struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ips,omitempty"` + Principals []string `json:"principals,omitempty"` +} + +type testSSHUserPolicy struct { + Allow *testSSHUserNames `json:"allow,omitempty"` + Deny *testSSHUserNames `json:"deny,omitempty"` +} + +type testSSHUserNames struct { + EmailAddresses []string `json:"emails,omitempty"` + Principals []string `json:"principals,omitempty"` +} + +// testPolicyResponse models the Policy API JSON response +type testPolicyResponse struct { + X509 *testX509Policy `json:"x509,omitempty"` + SSH *testSSHPolicy `json:"ssh,omitempty"` +} + func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { type test struct { auth adminAuthority adminDB admin.DB ctx context.Context err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -85,7 +139,42 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ - Dns: []string{"*.local"}, + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, }, }, } @@ -96,7 +185,48 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { return policy, nil }, }, - policy: policy, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, statusCode: 200, } }, @@ -114,29 +244,31 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { par.GetAuthorityPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.Message, ae.Message) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + assert.Equal(t, tc.response, p) }) } } @@ -149,7 +281,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { ctx context.Context acmeDB acme.DB err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -227,7 +359,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -272,7 +404,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -315,7 +447,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -336,8 +468,14 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { }, nil }, }, - body: body, - policy: policy, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, statusCode: 201, } }, @@ -355,21 +493,21 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { par.CreateAuthorityPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", @@ -377,15 +515,18 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { - assert.Equals(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.Message, ae.Message) } return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -399,7 +540,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { ctx context.Context acmeDB acme.DB err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -485,7 +626,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -530,7 +671,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -573,7 +714,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -594,8 +735,14 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { }, nil }, }, - body: body, - policy: policy, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, statusCode: 200, } }, @@ -613,21 +760,21 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { par.UpdateAuthorityPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", @@ -635,15 +782,18 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { - assert.Equals(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.Message, ae.Message) } return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -764,32 +914,32 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { par.DeleteAuthorityPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.Message, ae.Message) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) - assert.FatalError(t, err) + assert.NoError(t, err) res.Body.Close() response := DeleteResponse{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) - assert.Equals(t, "ok", response.Status) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) }) } @@ -802,7 +952,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { ctx context.Context acmeDB acme.DB err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -832,7 +982,42 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ - Dns: []string{"*.local"}, + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, }, }, } @@ -841,8 +1026,49 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, - policy: policy, + ctx: ctx, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, statusCode: 200, } }, @@ -860,28 +1086,31 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { par.GetProvisionerPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.Message, ae.Message) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -894,7 +1123,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body []byte ctx context.Context err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -964,7 +1193,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -999,7 +1228,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -1032,7 +1261,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -1040,8 +1269,14 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { return nil }, }, - body: body, - policy: policy, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, statusCode: 201, } }, @@ -1059,21 +1294,21 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { par.CreateProvisionerPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", @@ -1081,15 +1316,18 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { - assert.Equals(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.Message, ae.Message) } return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -1102,7 +1340,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { adminDB admin.DB ctx context.Context err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1173,7 +1411,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating provisioner policy") adminErr.Message = "error updating provisioner policy: admin lock out" body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -1209,7 +1447,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating provisioner policy: force") adminErr.Message = "error updating provisioner policy: force" body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -1243,7 +1481,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, auth: &mockAdminAuthority{ @@ -1251,8 +1489,14 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { return nil }, }, - body: body, - policy: policy, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, statusCode: 200, } }, @@ -1270,21 +1514,21 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { par.UpdateProvisionerPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", @@ -1292,15 +1536,18 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { - assert.Equals(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.Message, ae.Message) } return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -1391,32 +1638,32 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { par.DeleteProvisionerPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.Message, ae.Message) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) - assert.FatalError(t, err) + assert.NoError(t, err) res.Body.Close() response := DeleteResponse{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) - assert.Equals(t, "ok", response.Status) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) }) } @@ -1428,7 +1675,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { acmeDB acme.DB adminDB admin.DB err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1464,7 +1711,42 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ - Dns: []string{"*.local"}, + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, }, }, } @@ -1478,8 +1760,49 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ - ctx: ctx, - policy: policy, + ctx: ctx, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, statusCode: 200, } }, @@ -1497,28 +1820,31 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { par.GetACMEAccountPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.Message, ae.Message) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -1531,7 +1857,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body []byte ctx context.Context err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1610,13 +1936,13 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { - assert.Equals(t, "provID", provisionerID) - assert.Equals(t, "eakID", eak.ID) + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) return errors.New("force") }, }, @@ -1643,18 +1969,24 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { }, } body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { - assert.Equals(t, "provID", provisionerID) - assert.Equals(t, "eakID", eak.ID) + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) return nil }, }, - body: body, - policy: policy, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, statusCode: 201, } }, @@ -1672,21 +2004,21 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { par.CreateACMEAccountPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", @@ -1694,15 +2026,18 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { - assert.Equals(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.Message, ae.Message) } return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -1715,7 +2050,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body []byte ctx context.Context err *admin.Error - policy *linkedca.Policy + response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1795,13 +2130,13 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating ACME EAK policy: force") adminErr.Message = "error updating ACME EAK policy: force" body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { - assert.Equals(t, "provID", provisionerID) - assert.Equals(t, "eakID", eak.ID) + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) return errors.New("force") }, }, @@ -1829,18 +2164,24 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) body, err := protojson.Marshal(policy) - assert.FatalError(t, err) + assert.NoError(t, err) return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { - assert.Equals(t, "provID", provisionerID) - assert.Equals(t, "eakID", eak.ID) + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) return nil }, }, - body: body, - policy: policy, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, statusCode: 200, } }, @@ -1858,21 +2199,21 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { par.UpdateACMEAccountPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", @@ -1880,15 +2221,18 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { - assert.Equals(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.Message, ae.Message) } return } - p := &linkedca.Policy{} - assert.FatalError(t, readProtoJSON(res.Body, p)) - assert.Equals(t, tc.policy, p) + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) }) } @@ -1957,8 +2301,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { ctx: ctx, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { - assert.Equals(t, "provID", provisionerID) - assert.Equals(t, "eakID", eak.ID) + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) return errors.New("force") }, }, @@ -1988,8 +2332,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { ctx: ctx, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { - assert.Equals(t, "provID", provisionerID) - assert.Equals(t, "eakID", eak.ID) + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) return nil }, }, @@ -2010,32 +2354,32 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { par.DeleteACMEAccountPolicy(w, req) res := w.Result() - assert.Equals(t, tc.statusCode, res.StatusCode) + assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() - assert.FatalError(t, err) + assert.NoError(t, err) - ae := admin.Error{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - assert.Equals(t, tc.err.Type, ae.Type) - assert.Equals(t, tc.err.Message, ae.Message) - assert.Equals(t, tc.err.StatusCode(), res.StatusCode) - assert.Equals(t, tc.err.Detail, ae.Detail) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) - assert.FatalError(t, err) + assert.NoError(t, err) res.Body.Close() response := DeleteResponse{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) - assert.Equals(t, "ok", response.Status) - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) }) }