diff --git a/api/utils.go b/api/utils.go index b6ff7960..91091e25 100644 --- a/api/utils.go +++ b/api/utils.go @@ -2,14 +2,16 @@ package api import ( "encoding/json" + "errors" "io" "log" "net/http" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/logging" ) // EnableLogger is an interface that enables response logging for an object. @@ -114,3 +116,49 @@ func ReadProtoJSON(r io.Reader, m proto.Message) error { } return protojson.Unmarshal(data, m) } + +// ReadProtoJSONWithCheck reads JSON from the request body and stores it in the value +// pointed by v. TODO(hs): move this to and integrate with render package. +func ReadProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool { + data, err := io.ReadAll(r) + if err != nil { + var wrapper = struct { + Status int `json:"code"` + Message string `json:"message"` + }{ + Status: http.StatusBadRequest, + Message: err.Error(), + } + data, err := json.Marshal(wrapper) // TODO(hs): handle err; even though it's very unlikely to fail + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(data) + return false + } + if err := protojson.Unmarshal(data, m); err != nil { + if errors.Is(err, proto.Error) { + var wrapper = struct { + Message string `json:"message"` + }{ + Message: err.Error(), + } + data, err := json.Marshal(wrapper) // TODO(hs): handle err; even though it's very unlikely to fail + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(data) + return false + } + + // fallback to the default error writer + WriteError(w, err) + return false + } + + return true +} diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 34db5ea2..95b9ba98 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -26,8 +26,8 @@ type adminAuthority interface { UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) - StoreAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error - UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error + CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) RemoveAuthorityPolicy(ctx context.Context) error } diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index bcea31b5..d9592ff2 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -14,11 +14,13 @@ import ( "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/types/known/timestamppb" ) type mockAdminAuthority struct { @@ -39,7 +41,7 @@ type mockAdminAuthority struct { MockRemoveProvisioner func(ctx context.Context, id string) error MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) - MockStoreAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) (*linkedca.Policy, error) MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error MockRemoveAuthorityPolicy func(ctx context.Context) error } @@ -139,12 +141,12 @@ func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca. return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) StoreAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error { - return errors.New("not implemented yet") +func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error { - return errors.New("not implemented yet") +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") } func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error { diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index c30c7219..74bb2234 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -4,10 +4,11 @@ import ( "context" "net/http" + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin/db/nosql" - "go.step.sm/linkedca" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -42,7 +43,7 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - ctx := context.WithValue(r.Context(), admin.AdminContextKey, adm) + ctx := linkedca.WithAdmin(r.Context(), adm) next(w, r.WithContext(ctx)) } } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 2f64802f..30e05c48 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -4,10 +4,12 @@ import ( "net/http" "github.com/go-chi/chi" + + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) type policyAdminResponderInterface interface { @@ -82,29 +84,19 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r } var newPolicy = new(linkedca.Policy) - if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { - api.WriteError(w, err) + if !api.ReadProtoJSONWithCheck(w, r.Body, newPolicy) { return } - adm, err := adminFromContext(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving admin from context")) - return - } + adm := linkedca.AdminFromContext(ctx) - if err := par.auth.StoreAuthorityPolicy(ctx, adm, newPolicy); err != nil { + var createdPolicy *linkedca.Policy + if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return } - storedPolicy, err := par.auth.GetAuthorityPolicy(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) - return - } - - api.JSONStatus(w, storedPolicy, http.StatusCreated) + api.JSONStatus(w, createdPolicy, http.StatusCreated) } // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request @@ -134,24 +126,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r return } - adm, err := adminFromContext(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving admin from context")) - return - } + adm := linkedca.AdminFromContext(ctx) - if err := par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + var updatedPolicy *linkedca.Policy + if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) return } - newlyStoredPolicy, err := par.auth.GetAuthorityPolicy(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) - return - } - - api.ProtoJSONStatus(w, newlyStoredPolicy, http.StatusOK) + api.ProtoJSONStatus(w, updatedPolicy, http.StatusOK) } // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request diff --git a/authority/policy.go b/authority/policy.go index db44e5f4..ee132f31 100644 --- a/authority/policy.go +++ b/authority/policy.go @@ -25,42 +25,42 @@ func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, e return policy, nil } -func (a *Authority) StoreAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) error { +func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() if err := a.checkPolicy(ctx, adm, policy); err != nil { - return err + return nil, err } if err := a.adminDB.CreateAuthorityPolicy(ctx, policy); err != nil { - return err + return nil, err } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy") + return nil, admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy") } - return nil + return policy, nil // TODO: return the newly stored policy } -func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) error { +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() if err := a.checkPolicy(ctx, adm, policy); err != nil { - return err + return nil, err } if err := a.adminDB.UpdateAuthorityPolicy(ctx, policy); err != nil { - return err + return nil, err } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy") + return nil, admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy") } - return nil + return policy, nil // TODO: return the updated stored policy } func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { diff --git a/go.sum b/go.sum index ba7cb531..e7681592 100644 --- a/go.sum +++ b/go.sum @@ -639,8 +639,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= diff --git a/policy/engine.go b/policy/engine.go index 63d8452a..c37e1f59 100755 --- a/policy/engine.go +++ b/policy/engine.go @@ -4,7 +4,9 @@ import ( "bytes" "crypto/x509" "crypto/x509/pkix" + "errors" "fmt" + "io" "net" "net/url" "reflect" @@ -40,6 +42,30 @@ type NamePolicyError struct { Detail string } +type NameError struct { + error + Reason NamePolicyReason +} + +func a() { + err := io.EOF + var ne *NameError + errors.As(err, ne) + errors.Is(err, ne) +} + +func newPolicyError(reason NamePolicyReason, err error) error { + return &NameError{ + error: err, + Reason: reason, + } +} + +func newPolicyErrorf(reason NamePolicyReason, format string, args ...interface{}) error { + err := fmt.Errorf(format, args...) + return newPolicyError(reason, err) +} + func (e *NamePolicyError) Error() string { switch e.Reason { case NotAuthorizedForThisName: diff --git a/policy/engine_test.go b/policy/engine_test.go index f7a4b20a..cf406e71 100755 --- a/policy/engine_test.go +++ b/policy/engine_test.go @@ -8,8 +8,9 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/smallstep/assert" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" ) // TODO(hs): the functionality in the policy engine is a nice candidate for trying fuzzing on