Simplify statuscoder error generators.

pull/166/head^2
max furman 4 years ago
parent dccbdf3a90
commit 1cb8bb3ae1

@ -63,6 +63,7 @@ issues:
- declaration of "err" shadows declaration at line - declaration of "err" shadows declaration at line
- should have a package comment, unless it's in another file for this package - should have a package comment, unless it's in another file for this package
- error strings should not be capitalized or end with punctuation or a newline - error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args
# golangci.com configuration # golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration # https://github.com/golangci/golangci/wiki/Configuration
service: service:

@ -295,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := h.Authority.Root(sum)
if err != nil { if err != nil {
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI))) WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
} }
@ -314,13 +314,13 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r) cursor, limit, err := parseCursor(r)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
p, next, err := h.Authority.GetProvisioners(cursor, limit) p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &ProvisionersResponse{ JSON(w, &ProvisionersResponse{
@ -334,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid) key, err := h.Authority.GetEncryptedKey(kid)
if err != nil { if err != nil {
WriteError(w, errs.NotFound(err)) WriteError(w, errs.NotFoundErr(err))
return return
} }
JSON(w, &ProvisionerKeyResponse{key}) JSON(w, &ProvisionerKeyResponse{key})
@ -344,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := h.Authority.GetRoots()
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
@ -362,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := h.Authority.GetFederation()
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

@ -915,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, errs.Forbidden(fmt.Errorf("an error")), http.StatusForbidden}, {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
@ -1010,10 +1010,10 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expectedError400 := errs.BadRequest(errors.New("force")) expectedError400 := errs.BadRequest("force")
expectedError400Bytes, err := json.Marshal(expectedError400) expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err) assert.FatalError(t, err)
expectedError500 := errs.InternalServerError(errors.New("force")) expectedError500 := errs.InternalServer("force")
expectedError500Bytes, err := json.Marshal(expectedError500) expectedError500Bytes, err := json.Marshal(expectedError500)
assert.FatalError(t, err) assert.FatalError(t, err)
for _, tt := range tests { for _, tt := range tests {
@ -1082,7 +1082,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
} }
expected := []byte(`{"key":"` + privKey + `"}`) expected := []byte(`{"key":"` + privKey + `"}`)
expectedError404 := errs.NotFound(errors.New("force")) expectedError404 := errs.NotFound("force")
expectedError404Bytes, err := json.Marshal(expectedError404) expectedError404Bytes, err := json.Marshal(expectedError404)
assert.FatalError(t, err) assert.FatalError(t, err)

@ -3,7 +3,6 @@ package api
import ( import (
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -11,7 +10,7 @@ import (
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing peer certificate"))) WriteError(w, errs.BadRequest("missing peer certificate"))
return return
} }
@ -22,7 +21,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
var caPEM Certificate var caPEM Certificate
if len(certChainPEM) > 0 { if len(certChainPEM) > 1 {
caPEM = certChainPEM[1] caPEM = certChainPEM[1]
} }

@ -4,7 +4,6 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -30,13 +29,13 @@ type RevokeRequest struct {
// or an error if something is wrong. // or an error if something is wrong.
func (r *RevokeRequest) Validate() (err error) { func (r *RevokeRequest) Validate() (err error) {
if r.Serial == "" { if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial")) return errs.BadRequest("missing serial")
} }
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds")) return errs.BadRequest("reasonCode out of bounds")
} }
if !r.Passive { if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented")) return errs.NotImplemented("non-passive revocation not implemented")
} }
return return
@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 { if len(body.OTT) > 0 {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate"))) WriteError(w, errs.BadRequest("missing ott or peer certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body"))) WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.
@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
return errs.InternalServerError(errors.New("force")) return errs.InternalServer("force")
}, },
}, },
} }

@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
@ -22,13 +21,13 @@ type SignRequest struct {
// or an error if something is wrong. // or an error if something is wrong.
func (s *SignRequest) Validate() error { func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil { if s.CsrPEM.CertificateRequest == nil {
return errs.BadRequest(errors.New("missing csr")) return errs.BadRequest("missing csr")
} }
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.BadRequest(errors.Wrap(err, "invalid csr")) return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
} }
if s.OTT == "" { if s.OTT == "" {
return errs.BadRequest(errors.New("missing ott")) return errs.BadRequest("missing ott")
} }
return nil return nil
@ -49,7 +48,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
@ -66,18 +65,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT) signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
var caPEM Certificate var caPEM Certificate
if len(certChainPEM) > 0 { if len(certChainPEM) > 1 {
caPEM = certChainPEM[1] caPEM = certChainPEM[1]
} }
logCertificate(w, certChain[0]) logCertificate(w, certChain[0])

@ -249,19 +249,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return return
} }
@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -285,13 +285,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
addUserCertificate = &SSHCertificate{addUserCert} addUserCertificate = &SSHCertificate{addUserCert}
@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(cr, opts, signOpts...) certChain, err := h.Authority.Sign(cr, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)
@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots() keys, err := h.Authority.GetSSHRoots()
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found"))) WriteError(w, errs.NotFound("no keys found"))
return return
} }
@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation() keys, err := h.Authority.GetSSHFederation()
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found"))) WriteError(w, errs.NotFound("no keys found"))
return return
} }
@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data) ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
config.HostTemplates = ts config.HostTemplates = ts
default: default:
WriteError(w, errs.InternalServerError(errors.New("it should hot get here"))) WriteError(w, errs.InternalServer("it should hot get here"))
return return
} }
@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHCheckPrincipalResponse{ JSON(w, &SSHCheckPrincipalResponse{
@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(cert) hosts, err := h.Authority.GetSSHHosts(cert)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHGetHostsResponse{ JSON(w, &SSHGetHostsResponse{
@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }

@ -40,42 +40,42 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return return
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
} }
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...) newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
identity, err := h.renewIdentityCertificate(r) identity, err := h.renewIdentityCertificate(r)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

@ -36,36 +36,36 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT) _, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
} }
newCert, err := h.Authority.RenewSSH(oldCert) newCert, err := h.Authority.RenewSSH(oldCert)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
identity, err := h.renewIdentityCertificate(r) identity, err := h.renewIdentityCertificate(r)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

@ -4,7 +4,6 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -30,16 +29,16 @@ type SSHRevokeRequest struct {
// or an error if something is wrong. // or an error if something is wrong.
func (r *SSHRevokeRequest) Validate() (err error) { func (r *SSHRevokeRequest) Validate() (err error) {
if r.Serial == "" { if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial")) return errs.BadRequest("missing serial")
} }
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds")) return errs.BadRequest("reasonCode out of bounds")
} }
if !r.Passive { if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented")) return errs.NotImplemented("non-passive revocation not implemented")
} }
if len(r.OTT) == 0 { if len(r.OTT) == 0 {
return errs.BadRequest(errors.New("missing ott")) return errs.BadRequest("missing ott")
} }
return return
} }
@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
@ -71,13 +70,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

@ -6,7 +6,6 @@ import (
"log" "log"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
// pointed by v. // pointed by v.
func ReadJSON(r io.Reader, v interface{}) error { func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil { if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequest(errors.Wrap(err, "error decoding json")) return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
} }
return nil return nil
} }

@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
@ -58,15 +57,15 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// This check is meant as a stopgap solution to the current lack of a persistence layer. // This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, errs.Unauthorized(errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority")) return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
} }
} }
// This method will also validate the audiences for JWK provisioners. // This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok { if !ok {
return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: provisioner "+ return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))) "not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
} }
// Store the token to protect against reuse unless it's skipped. // Store the token to protect against reuse unless it's skipped.
@ -78,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
"authority.authorizeToken: failed when attempting to store token") "authority.authorizeToken: failed when attempting to store token")
} }
if !ok { if !ok {
return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: token already used")) return nil, errs.Unauthorized("authority.authorizeToken: token already used")
} }
} }
} }
@ -89,7 +88,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// Authorize grabs the method from the context and authorizes the request by // Authorize grabs the method from the context and authorizes the request by
// validating the one-time-token. // validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) { func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
var opts = []errs.Option{errs.WithKeyVal("token", token)} var opts = []interface{}{errs.WithKeyVal("token", token)}
switch m := provisioner.MethodFromContext(ctx); m { switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod: case provisioner.SignMethod:
@ -99,13 +98,13 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...) return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SSHSignMethod: case provisioner.SSHSignMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
} }
_, err := a.authorizeSSHSign(ctx, token) _, err := a.authorizeSSHSign(ctx, token)
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.SSHRenewMethod: case provisioner.SSHRenewMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
} }
_, err := a.authorizeSSHRenew(ctx, token) _, err := a.authorizeSSHRenew(ctx, token)
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
@ -113,12 +112,12 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...) return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SSHRekeyMethod: case provisioner.SSHRekeyMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
} }
_, signOpts, err := a.authorizeSSHRekey(ctx, token) _, signOpts, err := a.authorizeSSHRekey(ctx, token)
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
default: default:
return nil, errs.InternalServerError(errors.Errorf("authority.Authorize; method %d is not supported", m), opts...) return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...)
} }
} }
@ -165,7 +164,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
// //
// TODO(mariano): should we authorize by default? // TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
var opts = []errs.Option{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
// Check the passive revocation table. // Check the passive revocation table.
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String()) isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
@ -173,12 +172,12 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
} }
if isRevoked { if isRevoked {
return errs.Unauthorized(errors.New("authority.authorizeRenew: certificate has been revoked"), opts...) return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
} }
p, ok := a.provisioners.LoadByCertificate(cert) p, ok := a.provisioners.LoadByCertificate(cert)
if !ok { if !ok {
return errs.Unauthorized(errors.New("authority.authorizeRenew: provisioner not found"), opts...) return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
} }
if err := p.AuthorizeRenew(context.Background(), cert); err != nil { if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)

@ -180,7 +180,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
} }
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
_, err = _a.authorizeToken(context.TODO(), raw) _, err = _a.authorizeToken(context.Background(), raw)
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
@ -268,7 +268,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
p, err := tc.auth.authorizeToken(context.TODO(), tc.token) p, err := tc.auth.authorizeToken(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
@ -355,7 +355,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
if err := tc.auth.authorizeRevoke(context.TODO(), tc.token); err != nil { if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")

@ -80,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())) return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
} }
return nil return nil
} }

@ -306,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())) return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -353,7 +353,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token") return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, errs.InternalServerError(errors.New("aws.authorizeToken; error parsing token, header is missing")) return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
} }
var unsafeClaims awsPayload var unsafeClaims awsPayload
@ -378,13 +378,13 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
switch { switch {
case doc.AccountID == "": case doc.AccountID == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document accountId cannot be empty")) return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
case doc.InstanceID == "": case doc.InstanceID == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty")) return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
case doc.PrivateIP == "": case doc.PrivateIP == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty")) return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
case doc.Region == "": case doc.Region == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document region cannot be empty")) return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -399,7 +399,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) { if !matchesAudience(payload.Audience, p.audiences.Sign) {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)")) return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
} }
// Validate subject, it has to be known if disableCustomSANs is enabled // Validate subject, it has to be known if disableCustomSANs is enabled
@ -407,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
if payload.Subject != doc.InstanceID && if payload.Subject != doc.InstanceID &&
payload.Subject != doc.PrivateIP && payload.Subject != doc.PrivateIP &&
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)")) return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
} }
} }
@ -421,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid")) return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
} }
} }
// validate instance age // validate instance age
if d := p.InstanceAge.Value(); d > 0 { if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(doc.PendingTime) > d { if now.Sub(doc.PendingTime) > d {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document pendingTime is too old")) return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
} }
} }
@ -439,7 +439,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())) return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
if err != nil { if err != nil {
@ -462,7 +462,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}, },
} }
// Validate user options // Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options // Set defaults if not given as user options
signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
@ -474,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

@ -704,7 +704,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)

@ -210,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) {
return nil return nil
} }
// authorizeToken returs the claims, name, group, error. // authorizeToken returns the claims, name, group, error.
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) { func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token missing header")) return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
} }
var found bool var found bool
@ -230,7 +230,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
} }
} }
if !found { if !found {
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; cannot validate azure token")) return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
} }
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
@ -243,12 +243,12 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
// Validate TenantID // Validate TenantID
if claims.TenantID != p.TenantID { if claims.TenantID != p.TenantID {
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")) return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
} }
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 { if len(re) != 4 {
return nil, "", "", errs.Unauthorized(errors.Errorf("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)) return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
} }
group, name := re[2], re[3] group, name := re[2], re[3]
return &claims, name, group, nil return &claims, name, group, nil
@ -272,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.New("azure.AuthorizeSign; azure token validation failed - invalid resource group")) return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group")
} }
} }
@ -302,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())) return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -310,7 +310,7 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())) return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
} }
_, name, _, err := p.authorizeToken(token) _, name, _, err := p.authorizeToken(token)
@ -328,7 +328,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
Principals: []string{name}, Principals: []string{name},
} }
// Validate user options // Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options // Set defaults if not given as user options
signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
@ -340,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

@ -488,7 +488,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)

@ -243,7 +243,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())) return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -264,7 +264,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token") return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; error parsing gcp token - header is missing")) return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing")
} }
var found bool var found bool
@ -278,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.Errorf("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)) return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -293,7 +293,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) { if !matchesAudience(claims.Audience, p.audiences.Sign) {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")) return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
} }
// validate subject (service account) // validate subject (service account)
@ -306,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim")) return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim")
} }
} }
@ -320,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid project id")) return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id")
} }
} }
// validate instance age // validate instance age
if d := p.InstanceAge.Value(); d > 0 { if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d { if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")) return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")
} }
} }
switch { switch {
case claims.Google.ComputeEngine.InstanceID == "": case claims.Google.ComputeEngine.InstanceID == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")) return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")
case claims.Google.ComputeEngine.InstanceName == "": case claims.Google.ComputeEngine.InstanceName == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")) return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")
case claims.Google.ComputeEngine.ProjectID == "": case claims.Google.ComputeEngine.ProjectID == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")) return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")
case claims.Google.ComputeEngine.Zone == "": case claims.Google.ComputeEngine.Zone == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")) return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")
} }
return &claims, nil return &claims, nil
@ -348,7 +348,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())) return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
if err != nil { if err != nil {
@ -371,7 +371,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}, },
} }
// Validate user options // Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options // Set defaults if not given as user options
signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
@ -383,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

@ -680,7 +680,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)

@ -120,12 +120,12 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errs.Unauthorized(errors.Errorf("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s", return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
audiences, claims.Audience)) audiences, claims.Audience)
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("jwk.authorizeToken; jwk token subject cannot be empty")) return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty")
} }
return &claims, nil return &claims, nil
@ -173,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())) return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -181,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
} }
if claims.Step == nil || claims.Step.SSH == nil { if claims.Step == nil || claims.Step.SSH == nil {
return nil, errs.Unauthorized(errors.New("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")
} }
opts := claims.Step.SSH opts := claims.Step.SSH
signOptions := []SignOption{ signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token // validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts), sshCertOptionsValidator(*opts),
} }
t := now() t := now()
@ -231,9 +231,9 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

@ -222,7 +222,7 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -337,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)

@ -149,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
claims k8sSAPayload claims k8sSAPayload
) )
if p.pubKeys == nil { if p.pubKeys == nil {
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")) return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")
/* NOTE: We plan to support the TokenReview API in a future release. /* NOTE: We plan to support the TokenReview API in a future release.
Below is some code that should be useful when we prioritize Below is some code that should be useful when we prioritize
this integration. this integration.
@ -177,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
} }
} }
if !valid { if !valid {
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")) return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -189,7 +189,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA token subject cannot be empty")) return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty")
} }
return &claims, nil return &claims, nil
@ -221,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())) return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -229,7 +229,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign validates an request for an SSH certificate. // AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())) return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
} }
if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil { if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
@ -246,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

@ -363,10 +363,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case sshCertDefaultsModifier: case sshCertDefaultsModifier:
assert.Equals(t, v.CertType, SSHUserCert) assert.Equals(t, v.CertType, SSHUserCert)
case *sshDefaultExtensionModifier: case *sshDefaultExtensionModifier:
case *sshCertificateValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertificateDefaultValidator: case *sshCertDefaultValidator:
case *sshDefaultDuration: case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.claimer)
default: default:

@ -14,8 +14,8 @@ func Test_noop(t *testing.T) {
assert.Equals(t, "noop", p.GetName()) assert.Equals(t, "noop", p.GetName())
assert.Equals(t, noopType, p.GetType()) assert.Equals(t, noopType, p.GetType())
assert.Equals(t, nil, p.Init(Config{})) assert.Equals(t, nil, p.Init(Config{}))
assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{})) assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo")) assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))
kid, key, ok := p.GetEncryptedKey() kid, key, ok := p.GetEncryptedKey()
assert.Equals(t, "", kid) assert.Equals(t, "", kid)

@ -195,12 +195,12 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
// Validate azp if present // Validate azp if present
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: invalid azp")) return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp")
} }
// Enforce an email claim // Enforce an email claim
if p.Email == "" { if p.Email == "" {
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email not found")) return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found")
} }
// Validate domains (case-insensitive) // Validate domains (case-insensitive)
@ -214,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
} }
} }
if !found { if !found {
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email is not allowed")) return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
} }
} }
@ -230,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
} }
} }
if !found { if !found {
return errs.Unauthorized(errors.New("validatePayload: oidc token payload validation failed: invalid group")) return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group")
} }
} }
@ -263,7 +263,7 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.New("oidc.AuthorizeToken; cannot validate oidc token")) return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token")
} }
if err := o.ValidatePayload(claims); err != nil { if err := o.ValidatePayload(claims); err != nil {
@ -286,7 +286,7 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
if o.IsAdmin(claims.Email) { if o.IsAdmin(claims.Email) {
return nil return nil
} }
return errs.Unauthorized(errors.New("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")) return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
@ -318,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() { if o.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())) return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
} }
return nil return nil
} }
@ -326,7 +326,7 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() { if !o.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())) return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
} }
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
@ -352,7 +352,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Non-admin users can only use principals returned by the identityFunc, and // Non-admin users can only use principals returned by the identityFunc, and
// can only sign user certificates. // can only sign user certificates.
if !o.IsAdmin(claims.Email) { if !o.IsAdmin(claims.Email) {
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
} }
// Default to a user certificate with usernames as principals if those options // Default to a user certificate with usernames as principals if those options
@ -367,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{o.claimer}, &sshCertValidityValidator{o.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }
@ -382,7 +382,7 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// Only admins can revoke certificates. // Only admins can revoke certificates.
if !o.IsAdmin(claims.Email) { if !o.IsAdmin(claims.Email) {
return errs.Unauthorized(errors.New("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")) return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
} }
return nil return nil
} }

@ -284,43 +284,43 @@ type base struct{}
// AuthorizeSign returns an unimplmented error. Provisioners should overwrite // AuthorizeSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing x509 Certificates. // this method if they will support authorizing tokens for signing x509 Certificates.
func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSign not implemented")) return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
} }
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking x509 Certificates. // this method if they will support authorizing tokens for revoking x509 Certificates.
func (b *base) AuthorizeRevoke(ctx context.Context, token string) error { func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
return errs.Unauthorized(errors.New("provisioner.AuthorizeRevoke not implemented")) return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
} }
// AuthorizeRenew returns an unimplmented error. Provisioners should overwrite // AuthorizeRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing x509 Certificates. // this method if they will support authorizing tokens for renewing x509 Certificates.
func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
return errs.Unauthorized(errors.New("provisioner.AuthorizeRenew not implemented")) return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
} }
// AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite // AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing SSH Certificates. // this method if they will support authorizing tokens for signing SSH Certificates.
func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHSign not implemented")) return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
} }
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking SSH Certificates. // this method if they will support authorizing tokens for revoking SSH Certificates.
func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRevoke not implemented")) return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
} }
// AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite // AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing SSH Certificates. // this method if they will support authorizing tokens for renewing SSH Certificates.
func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRenew not implemented")) return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
} }
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite // AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for rekeying SSH Certificates. // this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRekey not implemented")) return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
} }
// Identity is the type representing an externally supplied identity that is used // Identity is the type representing an externally supplied identity that is used

@ -19,29 +19,29 @@ const (
SSHHostCert = "host" SSHHostCert = "host"
) )
// SSHCertificateModifier is the interface used to change properties in an SSH // SSHCertModifier is the interface used to change properties in an SSH
// certificate. // certificate.
type SSHCertificateModifier interface { type SSHCertModifier interface {
SignOption SignOption
Modify(cert *ssh.Certificate) error Modify(cert *ssh.Certificate) error
} }
// SSHCertificateOptionModifier is the interface used to add custom options used // SSHCertOptionModifier is the interface used to add custom options used
// to modify the SSH certificate. // to modify the SSH certificate.
type SSHCertificateOptionModifier interface { type SSHCertOptionModifier interface {
SignOption SignOption
Option(o SSHOptions) SSHCertificateModifier Option(o SSHOptions) SSHCertModifier
} }
// SSHCertificateValidator is the interface used to validate an SSH certificate. // SSHCertValidator is the interface used to validate an SSH certificate.
type SSHCertificateValidator interface { type SSHCertValidator interface {
SignOption SignOption
Valid(cert *ssh.Certificate) error Valid(cert *ssh.Certificate) error
} }
// SSHCertificateOptionsValidator is the interface used to validate the custom // SSHCertOptionsValidator is the interface used to validate the custom
// options used to modify the SSH certificate. // options used to modify the SSH certificate.
type SSHCertificateOptionsValidator interface { type SSHCertOptionsValidator interface {
SignOption SignOption
Valid(got SSHOptions) error Valid(got SSHOptions) error
} }
@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 {
return sshCertTypeUInt32(o.CertType) return sshCertTypeUInt32(o.CertType)
} }
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate. // Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
func (o SSHOptions) Modify(cert *ssh.Certificate) error { func (o SSHOptions) Modify(cert *ssh.Certificate) error {
switch o.CertType { switch o.CertType {
case "": // ignore case "": // ignore
@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error {
return nil return nil
} }
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the // sshCertPrincipalsModifier is an SSHCertModifier that sets the
// principals to the SSH certificate. // principals to the SSH certificate.
type sshCertPrincipalsModifier []string type sshCertPrincipalsModifier []string
@ -126,7 +126,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertKeyIDModifier is an SSHCertificateModifier that sets the given // sshCertKeyIDModifier is an SSHCertModifier that sets the given
// Key ID in the SSH certificate. // Key ID in the SSH certificate.
type sshCertKeyIDModifier string type sshCertKeyIDModifier string
@ -135,7 +135,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertTypeModifier is an SSHCertificateModifier that sets the // sshCertTypeModifier is an SSHCertModifier that sets the
// certificate type. // certificate type.
type sshCertTypeModifier string type sshCertTypeModifier string
@ -145,7 +145,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertValidAfterModifier is an SSHCertificateModifier that sets the // sshCertValidAfterModifier is an SSHCertModifier that sets the
// ValidAfter in the SSH certificate. // ValidAfter in the SSH certificate.
type sshCertValidAfterModifier uint64 type sshCertValidAfterModifier uint64
@ -154,7 +154,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertValidBeforeModifier is an SSHCertificateModifier that sets the // sshCertValidBeforeModifier is an SSHCertModifier that sets the
// ValidBefore in the SSH certificate. // ValidBefore in the SSH certificate.
type sshCertValidBeforeModifier uint64 type sshCertValidBeforeModifier uint64
@ -163,11 +163,11 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertDefaultsModifier implements a SSHCertificateModifier that // sshCertDefaultsModifier implements a SSHCertModifier that
// modifies the certificate with the given options if they are not set. // modifies the certificate with the given options if they are not set.
type sshCertDefaultsModifier SSHOptions type sshCertDefaultsModifier SSHOptions
// Modify implements the SSHCertificateModifier interface. // Modify implements the SSHCertModifier interface.
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error { func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
if cert.CertType == 0 { if cert.CertType == 0 {
cert.CertType = sshCertTypeUInt32(m.CertType) cert.CertType = sshCertTypeUInt32(m.CertType)
@ -184,7 +184,7 @@ func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets // sshDefaultExtensionModifier implements an SSHCertModifier that sets
// the default extensions in an SSH certificate. // the default extensions in an SSH certificate.
type sshDefaultExtensionModifier struct{} type sshDefaultExtensionModifier struct{}
@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
} }
} }
// sshDefaultDuration is an SSHCertificateModifier that sets the certificate // sshDefaultDuration is an SSHCertModifier that sets the certificate
// ValidAfter and ValidBefore if they have not been set. It will fail if a // ValidAfter and ValidBefore if they have not been set. It will fail if a
// CertType has not been set or is not valid. // CertType has not been set or is not valid.
type sshDefaultDuration struct { type sshDefaultDuration struct {
*Claimer *Claimer
} }
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier { func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier {
return sshModifierFunc(func(cert *ssh.Certificate) error { return sshModifierFunc(func(cert *ssh.Certificate) error {
d, err := m.DefaultSSHCertDuration(cert.CertType) d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil { if err != nil {
@ -248,7 +248,7 @@ type sshLimitDuration struct {
NotAfter time.Time NotAfter time.Time
} }
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier { func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier {
if m.NotAfter.IsZero() { if m.NotAfter.IsZero() {
defaultDuration := &sshDefaultDuration{m.Claimer} defaultDuration := &sshDefaultDuration{m.Claimer}
return defaultDuration.Option(o) return defaultDuration.Option(o)
@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
}) })
} }
// sshCertificateOptionsValidator validates the user SSHOptions with the ones // sshCertOptionsValidator validates the user SSHOptions with the ones
// usually present in the token. // usually present in the token.
type sshCertificateOptionsValidator SSHOptions type sshCertOptionsValidator SSHOptions
// Valid implements SSHCertificateOptionsValidator and returns nil if both // Valid implements SSHCertOptionsValidator and returns nil if both
// SSHOptions match. // SSHOptions match.
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error { func (v sshCertOptionsValidator) Valid(got SSHOptions) error {
want := SSHOptions(v) want := SSHOptions(v)
return want.match(got) return want.match(got)
} }
type sshCertificateValidityValidator struct { type sshCertValidityValidator struct {
*Claimer *Claimer
} }
func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return errors.New("ssh certificate validAfter cannot be 0")
@ -355,12 +355,12 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
} }
} }
// sshCertificateDefaultValidator implements a simple validator for all the // sshCertDefaultValidator implements a simple validator for all the
// fields in the SSH certificate. // fields in the SSH certificate.
type sshCertificateDefaultValidator struct{} type sshCertDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the necessary fields. // Valid returns an error if the given certificate does not contain the necessary fields.
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error { func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate) error {
switch { switch {
case len(cert.Nonce) == 0: case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty") return errors.New("ssh certificate nonce cannot be empty")

@ -489,12 +489,12 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
} }
} }
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) { func Test_sshCertDefaultValidator_Valid(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair() pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
sshPub, err := ssh.NewPublicKey(pub) sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertificateDefaultValidator{} v := sshCertDefaultValidator{}
tests := []struct { tests := []struct {
name string name string
cert *ssh.Certificate cert *ssh.Certificate
@ -670,10 +670,10 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
} }
} }
func Test_sshCertificateValidityValidator(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertificateValidityValidator{p.claimer} v := sshCertValidityValidator{p.claimer}
n := now() n := now()
tests := []struct { tests := []struct {
name string name string
@ -992,7 +992,7 @@ func Test_sshLimitDuration_Option(t *testing.T) {
name string name string
fields fields fields fields
args args args args
want SSHCertificateModifier want SSHCertModifier
}{ }{
// TODO: Add test cases. // TODO: Add test cases.
} }

@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
return nil, err return nil, err
} }
var mods []SSHCertificateModifier var mods []SSHCertModifier
var validators []SSHCertificateValidator var validators []SSHCertValidator
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// modify the ssh.Certificate // modify the ssh.Certificate
case SSHCertificateModifier: case SSHCertModifier:
mods = append(mods, o) mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions // modify the ssh.Certificate given the SSHOptions
case SSHCertificateOptionModifier: case SSHCertOptionModifier:
mods = append(mods, o.Option(opts)) mods = append(mods, o.Option(opts))
// validate the ssh.Certificate // validate the ssh.Certificate
case SSHCertificateValidator: case SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
// validate the given SSHOptions // validate the given SSHOptions
case SSHCertificateOptionsValidator: case SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil { if err := o.Valid(opts); err != nil {
return nil, err return nil, err
} }

@ -112,20 +112,20 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
return nil, errs.Wrap(http.StatusInternalServerError, err, return nil, errs.Wrap(http.StatusInternalServerError, err,
"sshpop.authorizeToken; error checking checking sshpop cert revocation") "sshpop.authorizeToken; error checking checking sshpop cert revocation")
} else if isRevoked { } else if isRevoked {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate is revoked")) return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
} }
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() n := time.Now()
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future")) return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
} }
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past")) return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
} }
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok { if !ok {
return nil, errs.InternalServerError(errors.New("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")) return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
} }
pubKey := sshCryptoPubKey.CryptoPublicKey() pubKey := sshCryptoPubKey.CryptoPublicKey()
@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
} }
} }
if !found { if !found {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")) return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
} }
// Using the ssh certificates key to validate the claims accomplishes two // Using the ssh certificates key to validate the claims accomplishes two
@ -170,12 +170,12 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errs.Unauthorized(errors.Errorf("sshpop.authorizeToken; sshpop token has invalid audience "+ return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
"claim (aud): expected %s, but got %s", audiences, claims.Audience)) "claim (aud): expected %s, but got %s", audiences, claims.Audience)
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty")) return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
} }
claims.sshCert = sshCert claims.sshCert = sshCert
@ -190,8 +190,8 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errs.BadRequest(errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject " + return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
"must be equivalent to sshpop certificate serial number")) "must be equivalent to sshpop certificate serial number")
} }
return nil return nil
} }
@ -204,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")) return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, nil return claims.sshCert, nil
@ -219,15 +219,15 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")) return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, []SignOption{ return claims.sshCert, []SignOption{
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
}, nil }, nil
} }

@ -564,8 +564,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertificateDefaultValidator: case *sshCertDefaultValidator:
case *sshCertificateValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))

@ -136,7 +136,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
leaf := verifiedChains[0][0] leaf := verifiedChains[0][0]
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")) return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
} }
// Using the leaf certificates key to validate the claims accomplishes two // Using the leaf certificates key to validate the claims accomplishes two
@ -160,12 +160,12 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errs.Unauthorized(errors.Errorf("x5c.authorizeToken; x5c token has invalid audience "+ return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", audiences, claims.Audience)) "claim (aud); expected %s, but got %s", audiences, claims.Audience)
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; x5c token subject cannot be empty")) return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty")
} }
// Save the verified chains on the x5c payload object. // Save the verified chains on the x5c payload object.
@ -213,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())) return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -221,7 +221,7 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.audiences.SSHSign)
@ -230,13 +230,13 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
} }
if claims.Step == nil || claims.Step.SSH == nil { if claims.Step == nil || claims.Step.SSH == nil {
return nil, errs.Unauthorized(errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
} }
opts := claims.Step.SSH opts := claims.Step.SSH
signOptions := []SignOption{ signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token // validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts), sshCertOptionsValidator(*opts),
} }
// Add modifiers from custom claims // Add modifiers from custom claims
@ -272,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

@ -548,7 +548,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -594,7 +594,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.TODO(), nil); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -754,7 +754,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -768,7 +768,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
nw := now() nw := now()
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case sshCertificateOptionsValidator: case sshCertOptionsValidator:
tc.claims.Step.SSH.ValidAfter.t = time.Time{} tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{} tc.claims.Step.SSH.ValidBefore.t = time.Time{}
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH) assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
@ -787,10 +787,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case *sshLimitDuration: case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertificateValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator, case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
*sshCertificateDefaultValidator: *sshCertDefaultValidator:
case sshCertKeyIDValidator: case sshCertKeyIDValidator:
assert.Equals(t, string(v), "foo") assert.Equals(t, string(v), "foo")
default: default:

@ -2,18 +2,16 @@ package authority
import ( import (
"crypto/x509" "crypto/x509"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
) )
// GetEncryptedKey returns the JWE key corresponding to the given kid argument. // GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) { func (a *Authority) GetEncryptedKey(kid string) (string, error) {
key, ok := a.provisioners.LoadEncryptedKey(kid) key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok { if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid), return "", errs.NotFound("encrypted key with kid %s was not found", kid)
http.StatusNotFound, apiCtx{}}
} }
return key, nil return key, nil
} }
@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt) p, ok := a.provisioners.LoadByCertificate(crt)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"), return nil, errs.NotFound("provisioner not found")
http.StatusNotFound, apiCtx{}}
} }
return p, nil return p, nil
} }
@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
p, ok := a.provisioners.Load(id) p, ok := a.provisioners.Load(id)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"), return nil, errs.NotFound("provisioner not found")
http.StatusNotFound, apiCtx{}}
} }
return p, nil return p, nil
} }

@ -7,13 +7,15 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
) )
func TestGetEncryptedKey(t *testing.T) { func TestGetEncryptedKey(t *testing.T) {
type ek struct { type ek struct {
a *Authority a *Authority
kid string kid string
err *apiError err error
code int
} }
tests := map[string]func(t *testing.T) *ek{ tests := map[string]func(t *testing.T) *ek{
"ok": func(t *testing.T) *ek { "ok": func(t *testing.T) *ek {
@ -32,10 +34,10 @@ func TestGetEncryptedKey(t *testing.T) {
a, err := New(c) a, err := New(c)
assert.FatalError(t, err) assert.FatalError(t, err)
return &ek{ return &ek{
a: a, a: a,
kid: "foo", kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"), err: errors.New("encrypted key with kid foo was not found"),
http.StatusNotFound, apiCtx{}}, code: http.StatusNotFound,
} }
}, },
} }
@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid) ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch v := err.(type) { sc, ok := err.(errs.StatusCoder)
case *apiError: assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.code, tc.err.code) assert.HasPrefix(t, err.Error(), tc.err.Error())
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {

@ -2,23 +2,20 @@ package authority
import ( import (
"crypto/x509" "crypto/x509"
"net/http"
"github.com/pkg/errors" "github.com/smallstep/certificates/errs"
) )
// Root returns the certificate corresponding to the given SHA sum argument. // Root returns the certificate corresponding to the given SHA sum argument.
func (a *Authority) Root(sum string) (*x509.Certificate, error) { func (a *Authority) Root(sum string) (*x509.Certificate, error) {
val, ok := a.certificates.Load(sum) val, ok := a.certificates.Load(sum)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum), return nil, errs.NotFound("certificate with fingerprint %s was not found", sum)
http.StatusNotFound, apiCtx{}}
} }
crt, ok := val.(*x509.Certificate) crt, ok := val.(*x509.Certificate)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), return nil, errs.InternalServer("stored value is not a *x509.Certificate")
http.StatusInternalServerError, apiCtx{}}
} }
return crt, nil return crt, nil
} }
@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
crt, ok := v.(*x509.Certificate) crt, ok := v.(*x509.Certificate)
if !ok { if !ok {
federation = nil federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), err = errs.InternalServer("stored value is not a *x509.Certificate")
http.StatusInternalServerError, apiCtx{}}
return false return false
} }
federation = append(federation, crt) federation = append(federation, crt)

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
) )
@ -16,12 +17,13 @@ func TestRoot(t *testing.T) {
a.certificates.Store("invaliddata", "a string") // invalid cert for testing a.certificates.Store("invaliddata", "a string") // invalid cert for testing
tests := map[string]struct { tests := map[string]struct {
sum string sum string
err *apiError err error
code int
}{ }{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}}, "not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}}, "invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK},
} }
for name, tc := range tests { for name, tc := range tests {
@ -29,14 +31,10 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum) crt, err := a.Root(tc.sum)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch v := err.(type) { sc, ok := err.(errs.StatusCoder)
case *apiError: assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.code, tc.err.code) assert.HasPrefix(t, err.Error(), tc.err.Error())
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {

@ -122,7 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
// GetSSHConfig returns rendered templates for clients (user) or servers (host). // GetSSHConfig returns rendered templates for clients (user) or servers (host).
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
return nil, errs.NotFound(errors.New("getSSHConfig: ssh is not configured")) return nil, errs.NotFound("getSSHConfig: ssh is not configured")
} }
var ts []templates.Template var ts []templates.Template
@ -136,7 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
ts = a.config.Templates.SSH.Host ts = a.config.Templates.SSH.Host
} }
default: default:
return nil, errs.BadRequest(errors.Errorf("getSSHConfig: type %s is not valid", typ)) return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
} }
// Merge user and default data // Merge user and default data
@ -177,13 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
} }
return nil, nil return nil, nil
} }
return nil, errs.NotFound(errors.New("authority.GetSSHBastion; ssh is not configured")) return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured")
} }
// SignSSH creates a signed SSH certificate with the given public key and options. // SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var mods []provisioner.SSHCertificateModifier var mods []provisioner.SSHCertModifier
var validators []provisioner.SSHCertificateValidator var validators []provisioner.SSHCertValidator
// Set backdate with the configured value // Set backdate with the configured value
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
@ -191,27 +191,27 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// modify the ssh.Certificate // modify the ssh.Certificate
case provisioner.SSHCertificateModifier: case provisioner.SSHCertModifier:
mods = append(mods, o) mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions // modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertificateOptionModifier: case provisioner.SSHCertOptionModifier:
mods = append(mods, o.Option(opts)) mods = append(mods, o.Option(opts))
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertificateValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
// validate the given SSHOptions // validate the given SSHOptions
case provisioner.SSHCertificateOptionsValidator: case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil { if err := o.Valid(opts); err != nil {
return nil, errs.Forbidden(err) return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
default: default:
return nil, errs.InternalServerError(errors.Errorf("signSSH: invalid extra option type %T", o)) return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
} }
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, errs.InternalServerError(err) return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH")
} }
var serial uint64 var serial uint64
@ -228,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// Use opts to modify the certificate // Use opts to modify the certificate
if err := opts.Modify(cert); err != nil { if err := opts.Modify(cert); err != nil {
return nil, errs.Forbidden(err) return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
// Use provisioner modifiers // Use provisioner modifiers
for _, m := range mods { for _, m := range mods {
if err := m.Modify(cert); err != nil { if err := m.Modify(cert); err != nil {
return nil, errs.Forbidden(err) return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
} }
@ -243,16 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("signSSH: user certificate signing is not enabled")) return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("signSSH: host certificate signing is not enabled")) return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, errs.InternalServerError(errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType)) return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
} }
cert.SignatureKey = signer.PublicKey() cert.SignatureKey = signer.PublicKey()
@ -270,7 +270,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// User provisioners validators // User provisioners validators
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert); err != nil {
return nil, errs.Forbidden(err) return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
} }
@ -285,7 +285,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, errs.InternalServerError(err) return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
} }
var serial uint64 var serial uint64
@ -294,7 +294,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
} }
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest(errors.New("rewnewSSH: cannot renew certificate without validity period")) return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
} }
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
@ -321,16 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("renewSSH: user certificate signing is not enabled")) return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled")
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("renewSSH: host certificate signing is not enabled")) return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled")
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, errs.InternalServerError(errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType)) return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType)
} }
cert.SignatureKey = signer.PublicKey() cert.SignatureKey = signer.PublicKey()
@ -354,21 +354,21 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var validators []provisioner.SSHCertificateValidator var validators []provisioner.SSHCertValidator
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertificateValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
default: default:
return nil, errs.InternalServerError(errors.Errorf("rekeySSH; invalid extra option type %T", o)) return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o)
} }
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, errs.InternalServerError(err) return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH")
} }
var serial uint64 var serial uint64
@ -377,7 +377,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
} }
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest(errors.New("rekeySSH; cannot rekey certificate without validity period")) return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
} }
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
@ -404,16 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("rekeySSH; user certificate signing is not enabled")) return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled")
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("rekeySSH; host certificate signing is not enabled")) return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled")
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, errs.BadRequest(errors.Errorf("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)) return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
} }
cert.SignatureKey = signer.PublicKey() cert.SignatureKey = signer.PublicKey()
@ -431,7 +431,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
// Apply validators from provisioner.. // Apply validators from provisioner..
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert); err != nil {
return nil, errs.Forbidden(err) return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
} }
} }
@ -445,18 +445,18 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
// SignSSHAddUser signs a certificate that provisions a new user in a server. // SignSSHAddUser signs a certificate that provisions a new user in a server.
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("signSSHAddUser: user certificate signing is not enabled")) return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
} }
if subject.CertType != ssh.UserCert { if subject.CertType != ssh.UserCert {
return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate is not a user certificate")) return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate")
} }
if len(subject.ValidPrincipals) != 1 { if len(subject.ValidPrincipals) != 1 {
return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate does not have only one principal")) return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal")
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, errs.InternalServerError(err) return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser")
} }
var serial uint64 var serial uint64

@ -80,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
type sshTestOptionsModifier string type sshTestOptionsModifier string
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier { func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier {
return sshTestCertModifier(string(m)) return sshTestCertModifier(string(m))
} }
@ -492,12 +492,12 @@ func TestAuthority_CheckSSHHost(t *testing.T) {
want bool want bool
wantErr bool wantErr bool
}{ }{
{"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false}, {"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false},
{"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false}, {"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false},
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

@ -61,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
// Sign creates a signed certificate from a certificate signing request. // Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var ( var (
opts = []errs.Option{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{} certValidators = []provisioner.CertificateValidator{}
issIdentity = a.intermediateIdentity issIdentity = a.intermediateIdentity
@ -81,7 +81,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
case provisioner.ProfileModifier: case provisioner.ProfileModifier:
mods = append(mods, k.Option(signOpts)) mods = append(mods, k.Option(signOpts))
default: default:
return nil, errs.InternalServerError(errors.Errorf("authority.Sign; invalid extra option type %T", k), opts...) return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
} }
} }
@ -131,7 +131,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
// Renew creates a new Certificate identical to the old certificate, except // Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'. // with a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
opts := []errs.Option{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())} opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
// Check step provisioner extensions // Check step provisioner extensions
if err := a.authorizeRenew(oldCert); err != nil { if err := a.authorizeRenew(oldCert); err != nil {
@ -237,7 +237,7 @@ type RevokeOptions struct {
// //
// TODO: Add OCSP and CRL support. // TODO: Add OCSP and CRL support.
func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error { func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error {
opts := []errs.Option{ opts := []interface{}{
errs.WithKeyVal("serialNumber", revokeOpts.Serial), errs.WithKeyVal("serialNumber", revokeOpts.Serial),
errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode), errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode),
errs.WithKeyVal("reason", revokeOpts.Reason), errs.WithKeyVal("reason", revokeOpts.Reason),
@ -281,7 +281,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
var ok bool var ok bool
p, ok = a.provisioners.LoadByToken(token, &claims.Claims) p, ok = a.provisioners.LoadByToken(token, &claims.Claims)
if !ok { if !ok {
return errs.InternalServerError(errors.Errorf("authority.Revoke; provisioner not found"), opts...) return errs.InternalServer("authority.Revoke; provisioner not found", opts...)
} }
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT) rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
if err != nil { if err != nil {
@ -309,10 +309,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
case nil: case nil:
return nil return nil
case db.ErrNotImplemented: case db.ErrNotImplemented:
return errs.NotImplemented(errors.New("authority.Revoke; no persistence layer configured"), opts...) return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
case db.ErrAlreadyExists: case db.ErrAlreadyExists:
return errs.BadRequest(errors.Errorf("authority.Revoke; certificate with serial "+ return errs.BadRequest("authority.Revoke; certificate with serial "+
"number %s has already been revoked", rci.Serial), opts...) "number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
default: default:
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
} }

@ -553,7 +553,7 @@ retry:
// verify the sha256 // verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw) sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
return nil, errs.BadRequest(errors.New("client.Root; root certificate SHA256 fingerprint do not match")) return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
} }
return &root, nil return &root, nil
} }
@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
retry: retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u, return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
errs.WithMessage("Failed to perform POST request to %s", u)) []interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -974,8 +974,8 @@ retry:
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil { if err := readJSON(resp.Body, &check); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u, return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")) []interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")})
} }
return &check, nil return &check, nil
} }

@ -163,8 +163,8 @@ func TestClient_Version(t *testing.T) {
expectedErr error expectedErr error
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"500", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, {"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"404", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, {"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -214,7 +214,7 @@ func TestClient_Health(t *testing.T) {
expectedErr error expectedErr error
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"not ok", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, {"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -268,7 +268,7 @@ func TestClient_Root(t *testing.T) {
expectedErr error expectedErr error
}{ }{
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil}, {"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
{"not found", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, {"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -336,9 +336,9 @@ func TestClient_Sign(t *testing.T) {
expectedErr error expectedErr error
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -409,8 +409,8 @@ func TestClient_Revoke(t *testing.T) {
expectedErr error expectedErr error
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -483,9 +483,9 @@ func TestClient_Renew(t *testing.T) {
err error err error
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -533,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) {
ok := &api.ProvisionersResponse{ ok := &api.ProvisionersResponse{
Provisioners: provisioner.List{}, Provisioners: provisioner.List{},
} }
internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error")) internalServerError := errs.InternalServer("Internal Server Error")
tests := []struct { tests := []struct {
name string name string
@ -603,7 +603,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
err error err error
}{ }{
{"ok", "kid", ok, 200, false, nil}, {"ok", "kid", ok, 200, false, nil},
{"fail", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, {"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -665,8 +665,8 @@ func TestClient_Roots(t *testing.T) {
err error err error
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -724,7 +724,7 @@ func TestClient_Federation(t *testing.T) {
err error err error
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -786,7 +786,7 @@ func TestClient_SSHRoots(t *testing.T) {
err error err error
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"not found", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, {"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -869,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) {
func TestClient_RootFingerprint(t *testing.T) { func TestClient_RootFingerprint(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"} ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error")) nok := errs.InternalServer("Internal Server Error")
httpsServer := httptest.NewTLSServer(nil) httpsServer := httptest.NewTLSServer(nil)
defer httpsServer.Close() defer httpsServer.Close()
@ -947,7 +947,7 @@ func TestClient_SSHBastion(t *testing.T) {
}{ }{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)

@ -62,31 +62,6 @@ type Error struct {
Details map[string]interface{} Details map[string]interface{}
} }
// New returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func New(status int, err error, opts ...Option) error {
var (
e *Error
ok bool
)
if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
}
for _, o := range opts {
o(e)
}
return e
}
// ErrorResponse represents an error in JSON format. // ErrorResponse represents an error in JSON format.
type ErrorResponse struct { type ErrorResponse struct {
Status int `json:"status"` Status int `json:"status"`
@ -119,10 +94,11 @@ func (e *Error) Message() string {
// Wrap returns an error annotating err with a stack trace at the point Wrap is // Wrap returns an error annotating err with a stack trace at the point Wrap is
// called, and the supplied message. If err is nil, Wrap returns nil. // called, and the supplied message. If err is nil, Wrap returns nil.
func Wrap(status int, e error, m string, opts ...Option) error { func Wrap(status int, e error, m string, args ...interface{}) error {
if e == nil { if e == nil {
return nil return nil
} }
_, opts := splitOptionArgs(args)
if err, ok := e.(*Error); ok { if err, ok := e.(*Error); ok {
err.Err = errors.Wrap(err.Err, m) err.Err = errors.Wrap(err.Err, m)
e = err e = err
@ -138,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
if e == nil { if e == nil {
return nil return nil
} }
var opts []Option as, opts := splitOptionArgs(args)
for i, arg := range args {
// Once we find the first Option, assume that all further arguments are Options.
if _, ok := arg.(Option); ok {
for _, a := range args[i:] {
// Ignore any arguments after the first Option that are not Options.
if opt, ok := a.(Option); ok {
opts = append(opts, opt)
}
}
args = args[:i]
break
}
}
if err, ok := e.(*Error); ok { if err, ok := e.(*Error); ok {
err.Err = errors.Wrapf(err.Err, format, args...) err.Err = errors.Wrapf(err.Err, format, args...)
e = err e = err
} else { } else {
e = errors.Wrapf(e, format, args...) e = errors.Wrapf(e, format, as...)
} }
return StatusCodeError(status, e, opts...) return StatusCodeError(status, e, opts...)
} }
@ -201,24 +164,24 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error { func StatusCodeError(code int, e error, opts ...Option) error {
switch code { switch code {
case http.StatusBadRequest: case http.StatusBadRequest:
return BadRequest(e, opts...) return BadRequestErr(e, opts...)
case http.StatusUnauthorized: case http.StatusUnauthorized:
return Unauthorized(e, opts...) return UnauthorizedErr(e, opts...)
case http.StatusForbidden: case http.StatusForbidden:
return Forbidden(e, opts...) return ForbiddenErr(e, opts...)
case http.StatusInternalServerError: case http.StatusInternalServerError:
return InternalServerError(e, opts...) return InternalServerErr(e, opts...)
case http.StatusNotImplemented: case http.StatusNotImplemented:
return NotImplemented(e, opts...) return NotImplementedErr(e, opts...)
default: default:
return UnexpectedError(code, e, opts...) return UnexpectedErr(code, e, opts...)
} }
} }
var ( var (
seeLogs = "Please see the certificate authority logs for more info." seeLogs = "Please see the certificate authority logs for more info."
// BadRequestDefaultMsg 400 default msg // BadRequestDefaultMsg 400 default msg
BadRequestDefaultMsg = "The request could not be completed due to being poorly formatted or missing critical data. " + seeLogs BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs
// UnauthorizedDefaultMsg 401 default msg // UnauthorizedDefaultMsg 401 default msg
UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs
// ForbiddenDefaultMsg 403 default msg // ForbiddenDefaultMsg 403 default msg
@ -231,46 +194,142 @@ var (
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
) )
// InternalServerError returns a 500 error with the given error. // splitOptionArgs splits the variadic length args into string formatting args
func InternalServerError(err error, opts ...Option) error { // and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
indexOptionStart := -1
for i, a := range args {
if _, ok := a.(Option); ok {
indexOptionStart = i
break
}
}
if indexOptionStart < 0 {
return args, []Option{}
}
opts := []Option{}
// Ignore any non-Option args that come after the first Option.
for _, o := range args[indexOptionStart:] {
if opt, ok := o.(Option); ok {
opts = append(opts, opt)
}
}
return args[:indexOptionStart], opts
}
// NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error {
var (
e *Error
ok bool
)
if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
}
for _, o := range opts {
o(e)
}
return e
}
// Errorf creates a new error using the given format and status code.
func Errorf(code int, format string, args ...interface{}) error {
as, opts := splitOptionArgs(args)
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
for _, o := range opts {
o(e)
}
return e
}
// InternalServer creates a 500 error with the given format and arguments.
func InternalServer(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
return Errorf(http.StatusInternalServerError, format, args...)
}
// InternalServerErr returns a 500 error with the given error.
func InternalServerErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg)) opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
return New(http.StatusInternalServerError, err, opts...) return NewErr(http.StatusInternalServerError, err, opts...)
} }
// NotImplemented returns a 501 error with the given error. // NotImplemented creates a 501 error with the given format and arguments.
func NotImplemented(err error, opts ...Option) error { func NotImplemented(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
return Errorf(http.StatusNotImplemented, format, args...)
}
// NotImplementedErr returns a 501 error with the given error.
func NotImplementedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg)) opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
return New(http.StatusNotImplemented, err, opts...) return NewErr(http.StatusNotImplemented, err, opts...)
}
// BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
return Errorf(http.StatusBadRequest, format, args...)
} }
// BadRequest returns an 400 error with the given error. // BadRequestErr returns an 400 error with the given error.
func BadRequest(err error, opts ...Option) error { func BadRequestErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return New(http.StatusBadRequest, err, opts...) return NewErr(http.StatusBadRequest, err, opts...)
}
// Unauthorized creates a 401 error with the given format and arguments.
func Unauthorized(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
return Errorf(http.StatusUnauthorized, format, args...)
} }
// Unauthorized returns an 401 error with the given error. // UnauthorizedErr returns an 401 error with the given error.
func Unauthorized(err error, opts ...Option) error { func UnauthorizedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg)) opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
return New(http.StatusUnauthorized, err, opts...) return NewErr(http.StatusUnauthorized, err, opts...)
} }
// Forbidden returns an 403 error with the given error. // Forbidden creates a 403 error with the given format and arguments.
func Forbidden(err error, opts ...Option) error { func Forbidden(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
return Errorf(http.StatusForbidden, format, args...)
}
// ForbiddenErr returns an 403 error with the given error.
func ForbiddenErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
return New(http.StatusForbidden, err, opts...) return NewErr(http.StatusForbidden, err, opts...)
}
// NotFound creates a 404 error with the given format and arguments.
func NotFound(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotFoundDefaultMsg))
return Errorf(http.StatusNotFound, format, args...)
} }
// NotFound returns an 404 error with the given error. // NotFoundErr returns an 404 error with the given error.
func NotFound(err error, opts ...Option) error { func NotFoundErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotFoundDefaultMsg)) opts = append(opts, withDefaultMessage(NotFoundDefaultMsg))
return New(http.StatusNotFound, err, opts...) return NewErr(http.StatusNotFound, err, opts...)
} }
// UnexpectedError will be used when the certificate authority makes an outgoing // UnexpectedErr will be used when the certificate authority makes an outgoing
// request and receives an unhandled status code. // request and receives an unhandled status code.
func UnexpectedError(code int, err error, opts ...Option) error { func UnexpectedErr(code int, err error, opts ...Option) error {
opts = append(opts, withDefaultMessage("The certificate authority received an "+ opts = append(opts, withDefaultMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code)) "unexpected HTTP status code - '%d'. "+seeLogs, code))
return New(code, err, opts...) return NewErr(code, err, opts...)
} }

Loading…
Cancel
Save