diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index c4779ea3..167d371a 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -9,12 +9,14 @@ import ( "encoding/asn1" "encoding/json" "net" + "net/http" "net/url" "reflect" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" ) @@ -83,19 +85,19 @@ type emailOnlyIdentity string func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error { switch { case len(req.DNSNames) > 0: - return errors.New("certificate request cannot contain DNS names") + return errs.BadRequest("certificate request cannot contain DNS names") case len(req.IPAddresses) > 0: - return errors.New("certificate request cannot contain IP addresses") + return errs.BadRequest("certificate request cannot contain IP addresses") case len(req.URIs) > 0: - return errors.New("certificate request cannot contain URIs") + return errs.BadRequest("certificate request cannot contain URIs") case len(req.EmailAddresses) == 0: - return errors.New("certificate request does not contain any email address") + return errs.BadRequest("certificate request does not contain any email address") case len(req.EmailAddresses) > 1: - return errors.New("certificate request contains too many email addresses") + return errs.BadRequest("certificate request contains too many email addresses") case req.EmailAddresses[0] == "": - return errors.New("certificate request cannot contain an empty email address") + return errs.BadRequest("certificate request cannot contain an empty email address") case req.EmailAddresses[0] != string(e): - return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e) + return errs.BadRequest("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e) default: return nil } @@ -108,12 +110,13 @@ type defaultPublicKeyValidator struct{} func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error { switch k := req.PublicKey.(type) { case *rsa.PublicKey: - if k.Size() < 256 { - return errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)") + if k.Size() < keyutil.MinRSAKeyBytes { + return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)", + 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) } case *ecdsa.PublicKey, ed25519.PublicKey: default: - return errors.Errorf("unrecognized public key of type '%T' in CSR", k) + return errs.BadRequest("certificate request key of type '%T' is not supported", k) } return nil } @@ -139,11 +142,12 @@ func (v publicKeyMinimumLengthValidator) Valid(req *x509.CertificateRequest) err case *rsa.PublicKey: minimumLengthInBytes := v.length / 8 if k.Size() < minimumLengthInBytes { - return errors.Errorf("rsa key in CSR must be at least %d bits (%d bytes)", v.length, minimumLengthInBytes) + return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)", + v.length, minimumLengthInBytes) } case *ecdsa.PublicKey, ed25519.PublicKey: default: - return errors.Errorf("unrecognized public key of type '%T' in CSR", k) + return errs.BadRequest("certificate request key of type '%T' is not supported", k) } return nil } @@ -158,7 +162,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error { return nil } if req.Subject.CommonName != string(v) { - return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v) + return errs.BadRequest("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) } return nil } @@ -176,7 +180,7 @@ func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error { return nil } } - return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v) + return errs.BadRequest("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) } // dnsNamesValidator validates the DNS names SAN of a certificate request. @@ -197,7 +201,7 @@ func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error { got[s] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) + return errs.BadRequest("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) } return nil } @@ -220,7 +224,7 @@ func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error { got[ip.String()] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v) + return errs.BadRequest("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v) } return nil } @@ -243,7 +247,7 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error { got[s] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("certificate request does not contain the valid Email Addresses - got %v, want %v", req.EmailAddresses, v) + return errs.BadRequest("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v) } return nil } @@ -266,7 +270,7 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error { got[u.String()] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("URIs claim failed - got %v, want %v", req.URIs, v) + return errs.BadRequest("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v) } return nil } @@ -334,15 +338,15 @@ func (v profileLimitDuration) Modify(cert *x509.Certificate, so SignOptions) err backdate = -1 * so.Backdate } if notBefore.Before(v.notBefore) { - return errors.Errorf("requested certificate notBefore (%s) is before "+ - "the active validity window of the provisioning credential (%s)", + return errs.Forbidden( + "requested certificate notBefore (%s) is before the active validity window of the provisioning credential (%s)", notBefore, v.notBefore) } notAfter := so.NotAfter.RelativeTime(notBefore) if notAfter.After(v.notAfter) { - return errors.Errorf("requested certificate notAfter (%s) is after "+ - "the expiration of the provisioning credential (%s)", + return errs.Forbidden( + "requested certificate notAfter (%s) is after the expiration of the provisioning credential (%s)", notAfter, v.notAfter) } if notAfter.IsZero() { @@ -422,16 +426,15 @@ func newForceCNOption(forceCN bool) *forceCNOption { func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { if !o.ForceCN { - // Forcing CN is disabled, do nothing to certificate return nil } + // Force the common name to be the first DNS if not provided. if cert.Subject.CommonName == "" { - if len(cert.DNSNames) > 0 { - cert.Subject.CommonName = cert.DNSNames[0] - } else { - return errors.New("Cannot force CN, DNSNames is empty") + if len(cert.DNSNames) == 0 { + return errs.Forbidden("cannot force common name, DNS names is empty") } + cert.Subject.CommonName = cert.DNSNames[0] } return nil @@ -456,7 +459,7 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValue func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) if err != nil { - return err + return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } // Prepend the provisioner extension. In the auth.Sign code we will // force the resulting certificate to only have one extension, the @@ -477,7 +480,7 @@ func createProvisionerExtension(typ int, name, credentialID string, keyValuePair KeyValuePairs: keyValuePairs, }) if err != nil { - return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension") + return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension") } return pkix.Extension{ Id: stepOIDProvisioner,