diff --git a/api/read/read.go b/api/read/read.go index 9c5ebd07..72530b8c 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -16,7 +16,7 @@ import ( ) // JSON reads JSON from the request body and stores it in the value -// pointed by v. +// pointed to by v. func JSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { return errs.BadRequestErr(err, "error decoding json") @@ -34,9 +34,7 @@ func ProtoJSON(r io.Reader, m proto.Message) error { switch err := protojson.Unmarshal(data, m); { case errors.Is(err, proto.Error): - // trim the proto prefix for the message - s := strings.TrimSpace(strings.TrimPrefix(err.Error(), "proto:")) - return badProtoJSONError(s) + return badProtoJSONError(err.Error()) default: return err } @@ -59,9 +57,10 @@ func (e badProtoJSONError) Render(w http.ResponseWriter) { Detail string `json:"detail"` Message string `json:"message"` }{ - Type: "badRequest", - Detail: "bad request", - Message: e.Error(), + Type: "badRequest", + Detail: "bad request", + // trim the proto prefix for the message + Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")), } render.JSONStatus(w, v, http.StatusBadRequest) } diff --git a/api/read/read_test.go b/api/read/read_test.go index f2eff1bc..8696ba78 100644 --- a/api/read/read_test.go +++ b/api/read/read_test.go @@ -1,10 +1,22 @@ package read import ( + "encoding/json" + "errors" "io" + "io/ioutil" + "net/http" + "net/http/httptest" "reflect" "strings" "testing" + "testing/iotest" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "go.step.sm/linkedca" "github.com/smallstep/certificates/errs" ) @@ -44,3 +56,110 @@ func TestJSON(t *testing.T) { }) } } + +func TestProtoJSON(t *testing.T) { + + p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import? + + type args struct { + r io.Reader + m proto.Message + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "fail/io.ReadAll", + args: args{ + r: iotest.ErrReader(errors.New("read error")), + m: p, + }, + wantErr: true, + }, + { + name: "fail/proto", + args: args{ + r: strings.NewReader(`{?}`), + m: p, + }, + wantErr: true, + }, + { + name: "ok", + args: args{ + r: strings.NewReader(`{"x509":{}}`), + m: p, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ProtoJSON(tt.args.r, tt.args.m) + if (err != nil) != tt.wantErr { + t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + switch err.(type) { + case badProtoJSONError: + assert.Contains(t, err.Error(), "syntax error") + case *errs.Error: + var ee *errs.Error + if errors.As(err, &ee) { + assert.Equal(t, http.StatusBadRequest, ee.Status) + } + } + return + } + + assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m)) + assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m)) + }) + } +} + +func Test_badProtoJSONError_Render(t *testing.T) { + tests := []struct { + name string + e badProtoJSONError + expected string + }{ + { + name: "bad proto normal space", + e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), + expected: "syntax error (line 1:2): invalid value ?", + }, + { + name: "bad proto non breaking space", + e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), + expected: "syntax error (line 1:2): invalid value ?", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + w := httptest.NewRecorder() + tt.e.Render(w) + res := w.Result() + defer res.Body.Close() + + data, err := ioutil.ReadAll(res.Body) + assert.NoError(t, err) + + v := struct { + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + }{} + + assert.NoError(t, json.Unmarshal(data, &v)) + assert.Equal(t, "badRequest", v.Type) + assert.Equal(t, "bad request", v.Detail) + assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message) + + }) + } +}