diff --git a/api/read/read.go b/api/read/read.go index de92c5d7..ee72cdb7 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -2,30 +2,91 @@ package read import ( + "bytes" "encoding/json" - "io" + "fmt" + "net/http" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/admin" + + "github.com/smallstep/certificates/internal/buffer" ) -// JSON reads JSON from the request body and stores it in the value -// pointed by v. -func JSON(r io.Reader, v interface{}) error { - if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.BadRequestErr(err, "error decoding json") +// JSON unmarshals from the given request's JSON body into v. In case of an +// error a HTTP Bad Request error will be written to w. +func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool { + b := read(w, r) + if b == nil { + return false } - return nil + defer buffer.Put(b) + + if err := json.NewDecoder(b).Decode(v); err != nil { + err = fmt.Errorf("error decoding json: %w", err) + + render.BadRequest(w, err) + + return false + } + + return true +} + +// AdminJSON is obsolete; it's here for backwards compatibility. +// +// Please don't use. +func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool { + b := read(w, r) + if b == nil { + return false + } + defer buffer.Put(b) + + if err := json.NewDecoder(b).Decode(v); err != nil { + e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body") + admin.WriteError(w, e) + + return false + } + + return true } // ProtoJSON reads JSON from the request body and stores it in the value // pointed by v. -func ProtoJSON(r io.Reader, m proto.Message) error { - data, err := io.ReadAll(r) - if err != nil { - return errs.BadRequestErr(err, "error reading request body") +func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool { + b := read(w, r) + if b == nil { + return false } - return protojson.Unmarshal(data, m) + defer buffer.Put(b) + + if err := protojson.Unmarshal(b.Bytes(), m); err != nil { + err = fmt.Errorf("error decoding proto json: %w", err) + + render.BadRequest(w, err) + + return false + } + + return true +} + +func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer { + b := buffer.Get() + if _, err := b.ReadFrom(r.Body); err != nil { + buffer.Put(b) + + err = fmt.Errorf("error reading request body: %w", err) + + render.BadRequest(w, err) + + return nil + } + + return b } diff --git a/api/read/read_test.go b/api/read/read_test.go index f2eff1bc..3fde9d60 100644 --- a/api/read/read_test.go +++ b/api/read/read_test.go @@ -2,45 +2,56 @@ package read import ( "io" - "reflect" + "net/http" + "net/http/httptest" + "strconv" "strings" "testing" + "testing/iotest" - "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" ) func TestJSON(t *testing.T) { - type args struct { - r io.Reader - v interface{} - } - tests := []struct { - name string - args args - wantErr bool + cases := []struct { + src io.Reader + exp interface{} + ok bool + code int }{ - {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, - {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, + 0: { + src: strings.NewReader(`{"foo":"bar"}`), + exp: map[string]interface{}{"foo": "bar"}, + ok: true, + code: http.StatusOK, + }, + 1: { + src: strings.NewReader(`{"foo"}`), + code: http.StatusBadRequest, + }, + 2: { + src: io.MultiReader( + strings.NewReader(`{`), + iotest.ErrReader(assert.AnError), + strings.NewReader(`"foo":"bar"}`), + ), + code: http.StatusBadRequest, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := JSON(tt.args.r, &tt.args.v) - if (err != nil) != tt.wantErr { - t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - e, ok := err.(*errs.Error) - if ok { - if code := e.StatusCode(); code != 400 { - t.Errorf("error.StatusCode() = %v, wants 400", code) - } - } else { - t.Errorf("error type = %T, wants *Error", err) - } - } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { - t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) - } + for caseIndex := range cases { + kase := cases[caseIndex] + + t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", kase.src) + rec := httptest.NewRecorder() + + var body interface{} + got := JSON(rec, req, &body) + + assert.Equal(t, kase.ok, got) + assert.Equal(t, kase.code, rec.Result().StatusCode) + assert.Equal(t, kase.exp, body) }) } }