From 8ce807a6cb409c0615cf9b9dfc2ba11455ec927d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 15:12:44 -0800 Subject: [PATCH] Modify errs.BadRequest() calls to always send an error to the client. --- api/api_test.go | 6 +++--- api/rekey.go | 3 +-- api/renew.go | 2 +- api/revoke.go | 4 ++-- api/revoke_test.go | 4 ++-- authority/provisioner/sshpop.go | 7 +++--- authority/provisioner/sshpop_test.go | 6 +++--- authority/ssh.go | 8 +++---- authority/ssh_test.go | 6 +++--- authority/tls.go | 6 ++++-- authority/tls_test.go | 2 +- ca/ca_test.go | 4 ++-- ca/client.go | 2 +- ca/client_test.go | 28 ++++++++++++------------ errs/error.go | 32 ++++++++++++++++++++++++++-- 15 files changed, 74 insertions(+), 46 deletions(-) diff --git a/api/api_test.go b/api/api_test.go index 05d592f0..0fab1a5b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) { fields fields err error }{ - {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")}, + {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")}, {"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")}, - {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")}, + {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := errs.BadRequest("force") + expectedError400 := errs.BadRequestErr(errors.New("force")) expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) expectedError500 := errs.InternalServer("force") diff --git a/api/rekey.go b/api/rekey.go index c0d88e55..2b60eabc 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -26,9 +26,8 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing peer certificate")) + WriteError(w, errs.BadRequest("missing client certificate")) return } diff --git a/api/renew.go b/api/renew.go index 74ef2034..725322ee 100644 --- a/api/renew.go +++ b/api/renew.go @@ -10,7 +10,7 @@ import ( // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing peer certificate")) + WriteError(w, errs.BadRequest("missing client certificate")) return } diff --git a/api/revoke.go b/api/revoke.go index 21c3154c..f3f47ebb 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -80,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or peer certificate")) + WriteError(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body")) + WriteError(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. diff --git a/api/revoke_test.go b/api/revoke_test.go index 4ed4e3fe..b6ba30fb 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) { tests := map[string]test{ "error/missing serial": { rr: &RevokeRequest{}, - err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest}, }, "error/bad reasonCode": { rr: &RevokeRequest{ @@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) { ReasonCode: 15, Passive: true, }, - err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest}, }, "error/non-passive not implemented": { rr: &RevokeRequest{ diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 99974ff1..3039d2a3 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -191,8 +191,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { - return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " + - "must be equivalent to sshpop certificate serial number") + return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number") } return nil } @@ -205,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { - return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate") + return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, nil @@ -220,7 +219,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } if claims.sshCert.CertType != ssh.HostCert { - return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate") + return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ // Validate public key diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 3d343967..850a698d 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"), + err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."), } }, "ok": func(t *testing.T) test { @@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"), + err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), } }, "ok": func(t *testing.T) test { @@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"), + err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), } }, "ok": func(t *testing.T) test { diff --git a/authority/ssh.go b/authority/ssh.go index bef673bf..eba48297 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -69,7 +69,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin ts = a.templates.SSH.Host } default: - return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ) + return nil, errs.BadRequest("invalid certificate type '%s'", typ) } // Merge user and default data @@ -258,7 +258,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period") + return nil, errs.BadRequest("cannot renew a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { @@ -329,7 +329,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") + return nil, errs.BadRequest("cannot rekey a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { @@ -369,7 +369,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } signer = a.sshCAHostCertSignKey default: - return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType) + return nil, errs.BadRequest("unexpected certificate type '%d'", cert.CertType) } var err error diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 994d015f..a62c9e54 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), code: http.StatusBadRequest, } }, @@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), code: http.StatusBadRequest, } }, @@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), + err: errors.New("The request could not be completed: unexpected certificate type '0'."), code: http.StatusBadRequest, } }, diff --git a/authority/tls.go b/authority/tls.go index 839866a2..4a5f2fdf 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -433,8 +433,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error case db.ErrNotImplemented: return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) case db.ErrAlreadyExists: - return errs.BadRequest("authority.Revoke; certificate with serial "+ - "number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...) + return errs.ApplyOptions( + errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial), + opts..., + ) default: return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } diff --git a/authority/tls_test.go b/authority/tls_test.go index ba05b9fc..1796c4a3 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) { Reason: reason, OTT: raw, }, - err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"), + err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { assert.Equals(t, err.Details["token"], raw) diff --git a/ca/ca_test.go b/ca/ca_test.go index 0f7cb02e..64371ac3 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -588,7 +588,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: nil, status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "request-missing-peer-certificate": func(t *testing.T) *renewTest { @@ -596,7 +596,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "success": func(t *testing.T) *renewTest { diff --git a/ca/client.go b/ca/client.go index b10c0f86..74a3b7df 100644 --- a/ca/client.go +++ b/ca/client.go @@ -662,7 +662,7 @@ retry: // verify the sha256 sum := sha256.Sum256(root.RootPEM.Raw) if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) { - return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") + return nil, errs.BadRequest("root certificate fingerprint does not match") } return &root, nil } diff --git a/ca/client_test.go b/ca/client_test.go index 187066f0..29a4848d 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -337,8 +337,8 @@ func TestClient_Sign(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, } srv := httptest.NewServer(nil) @@ -410,7 +410,7 @@ func TestClient_Revoke(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -455,7 +455,7 @@ func TestClient_Revoke(t *testing.T) { if got != nil { t.Errorf("Client.Revoke() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) @@ -484,8 +484,8 @@ func TestClient_Renew(t *testing.T) { }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -519,7 +519,7 @@ func TestClient_Renew(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -553,8 +553,8 @@ func TestClient_Rekey(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -588,7 +588,7 @@ func TestClient_Rekey(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -735,7 +735,7 @@ func TestClient_Roots(t *testing.T) { }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -768,7 +768,7 @@ func TestClient_Roots(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Roots() = %v, want %v", got, tt.response) @@ -1016,7 +1016,7 @@ func TestClient_SSHBastion(t *testing.T) { }{ {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, - {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -1050,7 +1050,7 @@ func TestClient_SSHBastion(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) } default: if !reflect.DeepEqual(got, tt.response) { diff --git a/errs/error.go b/errs/error.go index ebcf0894..ab488af1 100644 --- a/errs/error.go +++ b/errs/error.go @@ -194,6 +194,12 @@ var ( NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs ) +var ( + // BadRequestPrefix is the prefix added to the bad request messages that are + // directly sent to the cli. + BadRequestPrefix = "The request could not be completed: " +) + // splitOptionArgs splits the variadic length args into string formatting args // and Option(s) to apply to an Error. func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { @@ -218,6 +224,16 @@ func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { return args[:indexOptionStart], opts } +// New creates a new http error with the given status and message. +func New(status int, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + return &Error{ + Status: status, + Msg: msg, + Err: errors.New(msg), + } +} + // NewErr returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func NewErr(status int, err error, opts ...Option) error { @@ -254,6 +270,18 @@ func Errorf(code int, format string, args ...interface{}) error { return e } +// ApplyOptions applies the given options to the error if is the type *Error. +// TODO(mariano): try to get rid of this. +func ApplyOptions(err error, opts ...interface{}) error { + if e, ok := err.(*Error); ok { + _, o := splitOptionArgs(opts) + for _, fn := range o { + fn(e) + } + } + return err +} + // InternalServer creates a 500 error with the given format and arguments. func InternalServer(format string, args ...interface{}) error { args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) @@ -280,8 +308,8 @@ func NotImplementedErr(err error, opts ...Option) error { // BadRequest creates a 400 error with the given format and arguments. func BadRequest(format string, args ...interface{}) error { - args = append(args, withDefaultMessage(BadRequestDefaultMsg)) - return Errorf(http.StatusBadRequest, format, args...) + format = BadRequestPrefix + format + "." + return New(http.StatusBadRequest, format, args...) } // BadRequestErr returns an 400 error with the given error.