Improve handling of bad JSON protobuf bodies

pull/788/head
Herman Slatman 2 years ago
parent 2ca5c0170f
commit def9438ad6
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -10,7 +10,6 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
@ -24,62 +23,55 @@ func JSON(r io.Reader, v interface{}) error {
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
// pointed to by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}
// ProtoJSONWithCheck reads JSON from the request body and stores it in the value
// pointed to by m. Returns false if an error was written; true if not.
// TODO(hs): refactor this after the API flow changes are in (or before if that works)
func ProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool {
data, err := io.ReadAll(r)
if err != nil {
var wrapper = struct {
Status int `json:"code"`
Message string `json:"message"`
}{
Status: http.StatusBadRequest,
Message: err.Error(),
}
errData, err := json.Marshal(wrapper)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
return false
}
if err := protojson.Unmarshal(data, m); err != nil {
if errors.Is(err, proto.Error) {
var wrapper = struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{
Type: "badRequest",
Detail: "bad request",
Message: err.Error(),
}
errData, err := json.Marshal(wrapper)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
return false
return newBadProtoJSONError(err)
}
}
return err
}
// fallback to the default error writer
render.Error(w, err)
return false
// BadProtoJSONError is an error type that is used when a proto
// message cannot be unmarshaled. Usually this is caused by an error
// in the request body.
type BadProtoJSONError struct {
err error
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}
// newBadProtoJSONError returns a new instance of BadProtoJSONError
// This error type is always caused by an error in the request body.
func newBadProtoJSONError(err error) *BadProtoJSONError {
return &BadProtoJSONError{
err: err,
Type: "badRequest",
Detail: "bad request",
Message: err.Error(),
}
}
// Error implements the error interface
func (e *BadProtoJSONError) Error() string {
return e.err.Error()
}
// Render implements render.RenderableError for BadProtoError
func (e *BadProtoJSONError) Render(w http.ResponseWriter) {
errData, err := json.Marshal(e)
if err != nil {
panic(err)
}
return true
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
}

@ -80,7 +80,8 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
}
var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
@ -120,7 +121,8 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
}
var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
@ -195,7 +197,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
}
var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
@ -228,7 +231,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
}
var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
@ -297,7 +301,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
}
var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
@ -324,7 +329,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
}
var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}

@ -167,7 +167,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
statusCode: 409,
}
},
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
"fail/read.ProtoJSON": func(t *testing.T) test {
ctx := context.Background()
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
@ -410,7 +410,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
statusCode: 404,
}
},
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
"fail/read.ProtoJSON": func(t *testing.T) test {
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
@ -871,7 +871,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
statusCode: 409,
}
},
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
"fail/read.ProtoJSON": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Name: "provName",
}
@ -1060,7 +1060,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
statusCode: 404,
}
},
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
"fail/read.ProtoJSON": func(t *testing.T) test {
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
@ -1472,7 +1472,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
statusCode: 409,
}
},
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
"fail/read.ProtoJSON": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Name: "provName",
}
@ -1637,7 +1637,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
statusCode: 404,
}
},
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
"fail/read.ProtoJSON": func(t *testing.T) test {
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{

@ -8,18 +8,21 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestHandler_GetProvisioner(t *testing.T) {
@ -335,12 +338,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
return test{
ctx: context.Background(),
body: body,
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
Status: 500,
Detail: "",
Message: "",
statusCode: 400,
err: &admin.Error{
Type: "badRequest",
Status: 400,
Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
},
}
},
@ -423,9 +426,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return
}
@ -616,12 +625,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{
ctx: context.Background(),
body: body,
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
Status: 500,
Detail: "",
Message: "",
statusCode: 400,
err: &admin.Error{
Type: "badRequest",
Status: 400,
Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
},
}
},
@ -1074,9 +1083,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return
}

Loading…
Cancel
Save