diff --git a/api/api.go b/api/api.go index 0c16168f..33aa0f44 100644 --- a/api/api.go +++ b/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/cli/crypto/tlsutil" ) @@ -233,13 +234,13 @@ type ProvisionerKeyResponse struct { // or an error if something is wrong. func (s *SignRequest) Validate() error { if s.CsrPEM.CertificateRequest == nil { - return BadRequest(errors.New("missing csr")) + return errs.BadRequest(errors.New("missing csr")) } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return BadRequest(errors.Wrap(err, "invalid csr")) + return errs.BadRequest(errors.Wrap(err, "invalid csr")) } if s.OTT == "" { - return BadRequest(errors.New("missing ott")) + return errs.BadRequest(errors.New("missing ott")) } return nil @@ -328,7 +329,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { - WriteError(w, NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI))) + WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI))) return } @@ -349,7 +350,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } @@ -366,13 +367,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { signOpts, err := h.Authority.AuthorizeSign(body.OTT) if err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } certChainPEM := certChainToPEM(certChain) @@ -393,13 +394,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing peer certificate"))) + WriteError(w, errs.BadRequest(errors.New("missing peer certificate"))) return } certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } certChainPEM := certChainToPEM(certChain) @@ -421,13 +422,13 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := parseCursor(r) if err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } p, next, err := h.Authority.GetProvisioners(cursor, limit) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } JSON(w, &ProvisionersResponse{ @@ -441,7 +442,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := h.Authority.GetEncryptedKey(kid) if err != nil { - WriteError(w, NotFound(err)) + WriteError(w, errs.NotFound(err)) return } JSON(w, &ProvisionerKeyResponse{key}) @@ -451,7 +452,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } @@ -469,7 +470,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } diff --git a/api/errors.go b/api/errors.go index 90b41565..93057ed2 100644 --- a/api/errors.go +++ b/api/errors.go @@ -8,106 +8,10 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) -// StatusCoder interface is used by errors that returns the HTTP response code. -type StatusCoder interface { - StatusCode() int -} - -// StackTracer must be by those errors that return an stack trace. -type StackTracer interface { - StackTrace() errors.StackTrace -} - -// Error represents the CA API errors. -type Error struct { - Status int - Err error -} - -// ErrorResponse represents an error in JSON format. -type ErrorResponse struct { - Status int `json:"status"` - Message string `json:"message"` -} - -// Cause implements the errors.Causer interface and returns the original error. -func (e *Error) Cause() error { - return e.Err -} - -// Error implements the error interface and returns the error string. -func (e *Error) Error() string { - return e.Err.Error() -} - -// StatusCode implements the StatusCoder interface and returns the HTTP response -// code. -func (e *Error) StatusCode() int { - return e.Status -} - -// MarshalJSON implements json.Marshaller interface for the Error struct. -func (e *Error) MarshalJSON() ([]byte, error) { - return json.Marshal(&ErrorResponse{Status: e.Status, Message: http.StatusText(e.Status)}) -} - -// UnmarshalJSON implements json.Unmarshaler interface for the Error struct. -func (e *Error) UnmarshalJSON(data []byte) error { - var er ErrorResponse - if err := json.Unmarshal(data, &er); err != nil { - return err - } - e.Status = er.Status - e.Err = fmt.Errorf(er.Message) - return nil -} - -// NewError returns a new Error. If the given error implements the StatusCoder -// interface we will ignore the given status. -func NewError(status int, err error) error { - if sc, ok := err.(StatusCoder); ok { - return &Error{Status: sc.StatusCode(), Err: err} - } - cause := errors.Cause(err) - if sc, ok := cause.(StatusCoder); ok { - return &Error{Status: sc.StatusCode(), Err: err} - } - return &Error{Status: status, Err: err} -} - -// InternalServerError returns a 500 error with the given error. -func InternalServerError(err error) error { - return NewError(http.StatusInternalServerError, err) -} - -// NotImplemented returns a 500 error with the given error. -func NotImplemented(err error) error { - return NewError(http.StatusNotImplemented, err) -} - -// BadRequest returns an 400 error with the given error. -func BadRequest(err error) error { - return NewError(http.StatusBadRequest, err) -} - -// Unauthorized returns an 401 error with the given error. -func Unauthorized(err error) error { - return NewError(http.StatusUnauthorized, err) -} - -// Forbidden returns an 403 error with the given error. -func Forbidden(err error) error { - return NewError(http.StatusForbidden, err) -} - -// NotFound returns an 404 error with the given error. -func NotFound(err error) error { - return NewError(http.StatusNotFound, err) -} - // WriteError writes to w a JSON representation of the given error. func WriteError(w http.ResponseWriter, err error) { switch k := err.(type) { @@ -118,10 +22,10 @@ func WriteError(w http.ResponseWriter, err error) { w.Header().Set("Content-Type", "application/json") } cause := errors.Cause(err) - if sc, ok := err.(StatusCoder); ok { + if sc, ok := err.(errs.StatusCoder); ok { w.WriteHeader(sc.StatusCode()) } else { - if sc, ok := cause.(StatusCoder); ok { + if sc, ok := cause.(errs.StatusCoder); ok { w.WriteHeader(sc.StatusCode()) } else { w.WriteHeader(http.StatusInternalServerError) @@ -134,12 +38,12 @@ func WriteError(w http.ResponseWriter, err error) { "error": err, }) if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.(StackTracer); ok { + if e, ok := err.(errs.StackTracer); ok { rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e), }) } else { - if e, ok := cause.(StackTracer); ok { + if e, ok := cause.(errs.StackTracer); ok { rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e), }) diff --git a/api/revoke.go b/api/revoke.go index aceb8305..df974cbe 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "golang.org/x/crypto/ocsp" ) @@ -29,13 +30,13 @@ type RevokeRequest struct { // or an error if something is wrong. func (r *RevokeRequest) Validate() (err error) { if r.Serial == "" { - return BadRequest(errors.New("missing serial")) + return errs.BadRequest(errors.New("missing serial")) } if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { - return BadRequest(errors.New("reasonCode out of bounds")) + return errs.BadRequest(errors.New("reasonCode out of bounds")) } if !r.Passive { - return NotImplemented(errors.New("non-passive revocation not implemented")) + return errs.NotImplemented(errors.New("non-passive revocation not implemented")) } return @@ -49,7 +50,7 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } @@ -71,7 +72,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { if len(body.OTT) > 0 { logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } opts.OTT = body.OTT @@ -80,12 +81,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing ott or peer certificate"))) + WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate"))) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, BadRequest(errors.New("revoke: serial number in mtls certificate different than body"))) + WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body"))) return } // TODO: should probably be checking if the certificate was revoked here. @@ -96,7 +97,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } diff --git a/api/ssh.go b/api/ssh.go index 546c8f1e..f125a95a 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" "golang.org/x/crypto/ssh" @@ -248,19 +249,19 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) return } @@ -268,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey"))) return } } @@ -284,13 +285,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } @@ -298,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -319,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } certChain, err := h.Authority.Sign(cr, opts, signOpts...) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } identityCertificate = certChainToPEM(certChain) @@ -342,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHRoots() if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, NotFound(errors.New("no keys found"))) + WriteError(w, errs.NotFound(errors.New("no keys found"))) return } @@ -367,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation() if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, NotFound(errors.New("no keys found"))) + WriteError(w, errs.NotFound(errors.New("no keys found"))) return } @@ -392,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } if err := body.Validate(); err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } ts, err := h.Authority.GetSSHConfig(body.Type, body.Data) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } @@ -413,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: config.HostTemplates = ts default: - WriteError(w, InternalServerError(errors.New("it should hot get here"))) + WriteError(w, errs.InternalServerError(errors.New("it should hot get here"))) return } @@ -424,17 +425,17 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } JSON(w, &SSHCheckPrincipalResponse{ @@ -451,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { hosts, err := h.Authority.GetSSHHosts(cert) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } JSON(w, &SSHGetHostsResponse{ @@ -463,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } if err := body.Validate(); err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index 234a6df5..6b7ef5d7 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" ) @@ -38,36 +39,36 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) return } ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) } newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index 4324ebba..5a847796 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" ) // SSHRenewRequest is the request body of an SSH certificate request. @@ -34,30 +35,30 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, BadRequest(err)) + WriteError(w, errs.BadRequest(err)) return } ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, InternalServerError(err)) + WriteError(w, errs.InternalServerError(err)) } newCert, err := h.Authority.RenewSSH(oldCert) if err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 9355e5a4..93e0e450 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "golang.org/x/crypto/ocsp" ) @@ -29,16 +30,16 @@ type SSHRevokeRequest struct { // or an error if something is wrong. func (r *SSHRevokeRequest) Validate() (err error) { if r.Serial == "" { - return BadRequest(errors.New("missing serial")) + return errs.BadRequest(errors.New("missing serial")) } if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { - return BadRequest(errors.New("reasonCode out of bounds")) + return errs.BadRequest(errors.New("reasonCode out of bounds")) } if !r.Passive { - return NotImplemented(errors.New("non-passive revocation not implemented")) + return errs.NotImplemented(errors.New("non-passive revocation not implemented")) } if len(r.OTT) == 0 { - return BadRequest(errors.New("missing ott")) + return errs.BadRequest(errors.New("missing ott")) } return } @@ -49,7 +50,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) return } @@ -70,13 +71,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, Unauthorized(err)) + WriteError(w, errs.Unauthorized(err)) return } opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, Forbidden(err)) + WriteError(w, errs.Forbidden(err)) return } diff --git a/api/utils.go b/api/utils.go index 89adedb7..56beb2b5 100644 --- a/api/utils.go +++ b/api/utils.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) @@ -68,7 +69,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) { // pointed by v. func ReadJSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { - return BadRequest(errors.Wrap(err, "error decoding json")) + return errs.BadRequest(errors.Wrap(err, "error decoding json")) } return nil } diff --git a/authority/ssh.go b/authority/ssh.go index fbf97545..8148a6bd 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" "github.com/smallstep/cli/crypto/randutil" @@ -660,25 +661,19 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st if a.sshCheckHostFunc != nil { exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates()) if err != nil { - return false, &apiError{ - err: errors.Wrap(err, "checkSSHHost: error from injected checkSSHHost func"), - code: http.StatusInternalServerError, - } + return false, errs.Wrap(http.StatusInternalServerError, err, + "checkSSHHost: error from injected checkSSHHost func") } return exists, nil } exists, err := a.db.IsSSHHost(principal) if err != nil { if err == db.ErrNotImplemented { - return false, &apiError{ - err: errors.Wrap(err, "checkSSHHost: isSSHHost is not implemented"), - code: http.StatusNotImplemented, - } - } - return false, &apiError{ - err: errors.Wrap(err, "checkSSHHost: error checking if hosts exists"), - code: http.StatusInternalServerError, + return false, errs.Wrap(http.StatusNotImplemented, err, + "checkSSHHost: isSSHHost is not implemented") } + return false, errs.Wrap(http.StatusInternalServerError, err, + "checkSSHHost: error checking if hosts exists") } return exists, nil diff --git a/ca/client.go b/ca/client.go index a3bda21b..d42b1bd4 100644 --- a/ca/client.go +++ b/ca/client.go @@ -26,6 +26,7 @@ import ( "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca/identity" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/config" "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/pemutil" @@ -134,7 +135,7 @@ func (o *clientOptions) applyDefaultIdentity() error { } crt, err := i.TLSCertificate() if err != nil { - return nil + return err } o.certificate = crt return nil @@ -472,11 +473,6 @@ func (c *Client) GetRootCAs() *x509.CertPool { } } -// GetTransport returns the transport of the internal HTTP client. -func (c *Client) GetTransport() http.RoundTripper { - return c.client.GetTransport() -} - // SetTransport updates the transport of the internal HTTP client. func (c *Client) SetTransport(tr http.RoundTripper) { c.client.SetTransport(tr) @@ -958,24 +954,27 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin Token: token, }) if err != nil { - return nil, errors.Wrap(err, "error marshaling request") + return nil, errs.Wrap(http.StatusInternalServerError, err, + "error marshaling check-host request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"}) retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { - return nil, errors.Wrapf(err, "client POST %s failed", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u, + errs.WithMessage("Failed to perform POST request to %s", u)) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } - return nil, readError(resp.Body) + + return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body)) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u) } return &check, nil } @@ -1174,7 +1173,7 @@ func readJSON(r io.ReadCloser, v interface{}) error { func readError(r io.ReadCloser) error { defer r.Close() - apiErr := new(api.Error) + apiErr := new(errs.Error) if err := json.NewDecoder(r).Decode(apiErr); err != nil { return err } diff --git a/errs/error.go b/errs/error.go new file mode 100644 index 00000000..825cf549 --- /dev/null +++ b/errs/error.go @@ -0,0 +1,250 @@ +package errs + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" +) + +// StatusCoder interface is used by errors that returns the HTTP response code. +type StatusCoder interface { + StatusCode() int +} + +// StackTracer must be by those errors that return an stack trace. +type StackTracer interface { + StackTrace() errors.StackTrace +} + +// Option modifies the Error type. +type Option func(e *Error) error + +// WithMessage returns an Option that modifies the error by overwriting the +// message only if it is empty. +func WithMessage(format string, args ...interface{}) Option { + return func(e *Error) error { + if len(e.Msg) > 0 { + return e + } + e.Msg = fmt.Sprintf(format, args...) + return e + } +} + +// Error represents the CA API errors. +type Error struct { + Status int + Err error + Msg string +} + +// New returns a new Error. If the given error implements the StatusCoder +// interface we will ignore the given status. +func New(status int, err error, opts ...Option) error { + var e *Error + if sc, ok := err.(StatusCoder); ok { + e = &Error{Status: sc.StatusCode(), Err: err} + } else { + cause := errors.Cause(err) + if sc, ok := cause.(StatusCoder); ok { + e = &Error{Status: sc.StatusCode(), Err: err} + } else { + e = &Error{Status: status, Err: err} + } + } + for _, o := range opts { + o(e) + } + return e +} + +// ErrorResponse represents an error in JSON format. +type ErrorResponse struct { + Status int `json:"status"` + Message string `json:"message"` +} + +// Cause implements the errors.Causer interface and returns the original error. +func (e *Error) Cause() error { + return e.Err +} + +// Error implements the error interface and returns the error string. +func (e *Error) Error() string { + return e.Err.Error() +} + +// StatusCode implements the StatusCoder interface and returns the HTTP response +// code. +func (e *Error) StatusCode() int { + return e.Status +} + +// Message returns a user friendly error, if one is set. +func (e *Error) Message() string { + if len(e.Msg) > 0 { + return e.Msg + } + return e.Err.Error() +} + +// Wrap returns an error annotating err with a stack trace at the point Wrap is +// called, and the supplied message. If err is nil, Wrap returns nil. +func Wrap(status int, e error, m string, opts ...Option) error { + if e == nil { + return nil + } + if err, ok := e.(*Error); ok { + err.Err = errors.Wrap(err.Err, m) + e = err + } else { + e = errors.Wrap(e, m) + } + return StatusCodeError(status, e, opts...) +} + +// Wrapf returns an error annotating err with a stack trace at the point Wrap is +// called, and the supplied message. If err is nil, Wrap returns nil. +func Wrapf(status int, e error, format string, args ...interface{}) error { + if e == nil { + return nil + } + var opts []Option + for i, arg := range args { + // Once we find the first Option, assume that all further arguments are Options. + if _, ok := arg.(Option); ok { + for _, a := range args[i:] { + // Ignore any arguments after the first Option that are not Options. + if opt, ok := a.(Option); ok { + opts = append(opts, opt) + } + } + args = args[:i] + break + } + } + if err, ok := e.(*Error); ok { + err.Err = errors.Wrapf(err.Err, format, args...) + e = err + } else { + e = errors.Wrapf(e, format, args...) + } + return StatusCodeError(status, e, opts...) +} + +// MarshalJSON implements json.Marshaller interface for the Error struct. +func (e *Error) MarshalJSON() ([]byte, error) { + var msg string + if len(e.Msg) > 0 { + msg = e.Msg + } else { + msg = http.StatusText(e.Status) + } + return json.Marshal(&ErrorResponse{Status: e.Status, Message: msg}) +} + +// UnmarshalJSON implements json.Unmarshaler interface for the Error struct. +func (e *Error) UnmarshalJSON(data []byte) error { + var er ErrorResponse + if err := json.Unmarshal(data, &er); err != nil { + return err + } + e.Status = er.Status + e.Err = fmt.Errorf(er.Message) + return nil +} + +// Format implements the fmt.Formatter interface. +func (e *Error) Format(f fmt.State, c rune) { + if err, ok := e.Err.(fmt.Formatter); ok { + err.Format(f, c) + return + } + fmt.Fprint(f, e.Err.Error()) +} + +// Messenger is a friendly message interface that errors can implement. +type Messenger interface { + Message() string +} + +// StatusCodeError selects the proper error based on the status code. +func StatusCodeError(code int, e error, opts ...Option) error { + switch code { + case http.StatusBadRequest: + return BadRequest(e, opts...) + case http.StatusUnauthorized: + return Unauthorized(e, opts...) + case http.StatusForbidden: + return Forbidden(e, opts...) + case http.StatusInternalServerError: + return InternalServerError(e, opts...) + case http.StatusNotImplemented: + return NotImplemented(e, opts...) + default: + return UnexpectedError(code, e, opts...) + } +} + +var seeLogs = "Please see the certificate authority logs for more info." + +// InternalServerError returns a 500 error with the given error. +func InternalServerError(err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The certificate authority encountered an Internal Server Error. "+seeLogs)) + } + return New(http.StatusInternalServerError, err, opts...) +} + +// NotImplemented returns a 501 error with the given error. +func NotImplemented(err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The requested method is not implemented by the certificate authority. "+seeLogs)) + } + return New(http.StatusNotImplemented, err, opts...) +} + +// BadRequest returns an 400 error with the given error. +func BadRequest(err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The request could not be completed due to being poorly formatted or "+ + "missing critical data. "+seeLogs)) + } + return New(http.StatusBadRequest, err, opts...) +} + +// Unauthorized returns an 401 error with the given error. +func Unauthorized(err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The request lacked necessary authorization to be completed. "+seeLogs)) + } + return New(http.StatusUnauthorized, err, opts...) +} + +// Forbidden returns an 403 error with the given error. +func Forbidden(err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The request was Forbidden by the certificate authority. "+seeLogs)) + } + return New(http.StatusForbidden, err, opts...) +} + +// NotFound returns an 404 error with the given error. +func NotFound(err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The requested resource could not be found. "+seeLogs)) + } + return New(http.StatusNotFound, err, opts...) +} + +// UnexpectedError will be used when the certificate authority makes an outgoing +// request and receives an unhandled status code. +func UnexpectedError(code int, err error, opts ...Option) error { + if len(opts) == 0 { + opts = append(opts, WithMessage("The certificate authority received an "+ + "unexpected HTTP status code - '%d'. "+seeLogs, code)) + } + return New(code, err, opts...) +} diff --git a/go.mod b/go.mod index b8311236..c80084af 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/Masterminds/sprig/v3 v3.0.0 github.com/go-chi/chi v4.0.2+incompatible + github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a // indirect github.com/newrelic/go-agent v2.15.0+incompatible github.com/pkg/errors v0.8.1 github.com/rs/xid v1.2.1 @@ -18,4 +19,4 @@ require ( gopkg.in/square/go-jose.v2 v2.4.0 ) -//replace github.com/smallstep/cli => ../cli +replace github.com/smallstep/cli => ../cli diff --git a/go.sum b/go.sum index 9e9246ba..2f65e71d 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,8 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10= github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= +github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= @@ -177,6 +179,8 @@ golang.org/x/sys v0.0.0-20190424175732-18eb32c0e2f0/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c=