diff --git a/authority/authorize.go b/authority/authorize.go index 5108f567..4f64921b 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -276,6 +276,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} + isRevoked, err := a.IsRevoked(serial) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) @@ -283,7 +284,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { if isRevoked { return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } - p, ok := a.provisioners.LoadByCertificate(cert) if !ok { return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6d524a25..74f313e7 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -753,6 +753,7 @@ func TestAuthority_Authorize(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) { fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") + fooCrt.NotAfter = time.Now().Add(time.Hour) assert.FatalError(t, err) renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") @@ -822,7 +823,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { return &authorizeTest{ auth: a, cert: renewDisabledCrt, - err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"), code: http.StatusUnauthorized, } }, @@ -909,6 +910,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner. } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -917,6 +919,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 21958d36..913d0ace 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -6,7 +6,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" ) // ACME is the acme provisioner type, an entity that can authorize the ACME @@ -24,7 +23,7 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (p *ACME) DefaultTLSCertDuration() time.Duration { - return p.claimer.DefaultTLSCertDuration() + return p.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a JWK type. @@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign does not do any validation, because all validation is handled @@ -97,10 +92,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), newForceCNOption(p.ForceCN), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -118,8 +113,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index bd173f87..86e8a9a9 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) { } func TestACME_AuthorizeRenew(t *testing.T) { + now := time.Now() type test struct { p *ACME cert *x509.Certificate @@ -104,21 +105,27 @@ func TestACME_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateACME() assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -179,11 +186,11 @@ func TestACME_AuthorizeSign(t *testing.T) { case *forceCNOption: assert.Equals(t, v.ForceCN, tc.p.ForceCN) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index fdad7b4a..5f79d7d0 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -264,9 +264,8 @@ type AWS struct { IIDRoots string `json:"iidRoots,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *awsConfig - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Add default config if p.config, err = newAWSConfig(p.IIDRoots); err != nil { return err } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) // validate IMDS versions if len(p.IMDSVersions) == 0 { @@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) { } } - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -473,11 +470,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -486,10 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized @@ -664,7 +658,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { } // validate audiences with the defaults - if !matchesAudience(payload.Audience, p.audiences.Sign) { + if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") } @@ -704,7 +698,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -752,11 +746,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 0d2786db..2e684272 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -682,13 +682,13 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Len(t, 2, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), tt.args.cn) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: @@ -726,7 +726,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") @@ -747,7 +747,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -824,6 +824,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { } func TestAWS_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateAWS() assert.FatalError(t, err) p2, err := generateAWS() @@ -832,7 +833,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -845,8 +846,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 384617e0..d9654566 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -96,10 +96,10 @@ type Azure struct { DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *azureConfig oidcConfig openIDConfiguration keyStore *keyStore + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -203,27 +203,24 @@ func (p *Azure) Init(config Config) (err error) { case p.Audience == "": // use default audience p.Audience = azureDefaultAudience } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint - if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { - return err + if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { + return } if err := p.oidcConfig.Validate(); err != nil { return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) } // Get JWK key set if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { - return err + return } - return nil + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken returns the claims, name, group, subscription, identityObjectID, error. @@ -355,10 +352,10 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -367,15 +364,12 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) } @@ -420,11 +414,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 4ab734d5..c40d0f93 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -511,13 +511,13 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "virtualMachine") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -536,6 +536,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { } func TestAzure_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() @@ -544,7 +545,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -557,8 +558,14 @@ func TestAzure_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -595,7 +602,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("subject", "caURL") @@ -616,7 +623,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"virtualMachine"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index e46f4ce4..6070b640 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -88,10 +88,9 @@ type GCP struct { InstanceAge Duration `json:"instanceAge,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *gcpConfig keyStore *keyStore - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name should uniquely @@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { } // Init validates and initializes the GCP provisioner. -func (p *GCP) Init(config Config) error { - var err error +func (p *GCP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Initialize key store - p.keyStore, err = newKeyStore(p.config.CertsURL) - if err != nil { - return err + if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil { + return } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -269,19 +265,16 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized. @@ -328,7 +321,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } // validate audiences with the defaults - if !matchesAudience(claims.Audience, p.audiences.Sign) { + if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") } @@ -383,7 +376,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -431,11 +424,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 5f6f9bc7..2fc7fee0 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -554,13 +554,13 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Len(t, 4, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration()) case commonNameSliceValidator: assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -595,7 +595,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := generateGCPToken(p1.ServiceAccounts[0], @@ -622,7 +622,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -698,6 +698,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { } func TestGCP_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() @@ -706,7 +707,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -719,8 +720,14 @@ func TestGCP_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renewal-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 137915c8..764f5d7d 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -35,8 +35,9 @@ type JWK struct { EncryptedKey string `json:"encryptedKey,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + // claimer *Claimer + // audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -98,13 +99,8 @@ func (p *JWK) Init(config Config) (err error) { return errors.New("provisioner key cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -146,13 +142,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } @@ -179,12 +175,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators commonNameValidator(claims.Subject), defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -193,18 +189,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } @@ -261,11 +254,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil @@ -273,6 +266,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.SSHRevoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index deae8f7a..f6b2d93c 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -76,13 +76,13 @@ func TestJWK_Init(t *testing.T) { }, "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, } }, } @@ -305,13 +305,13 @@ func TestJWK_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "subject") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) default: @@ -325,6 +325,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } func TestJWK_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() @@ -333,7 +334,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -346,8 +347,14 @@ func TestJWK_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -373,7 +380,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p2.Claims = &Claims{EnableSSHCA: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p1.EncryptedKey) @@ -402,8 +409,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), @@ -485,8 +492,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { signer, err := generateJSONWebKey() assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index d260f5ec..557d571a 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -42,16 +42,15 @@ type k8sSAPayload struct { // entity trusted to make signature requests. type K8sSA struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - PubKeys []byte `json:"publicKeys,omitempty"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + PubKeys []byte `json:"publicKeys,omitempty"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` //kauthn kauthn.AuthenticationV1Interface pubKeys []interface{} + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) { p.kauthn = k8s.AuthenticationV1() */ - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } @@ -240,27 +234,24 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") } @@ -282,11 +273,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Require type, key-id and principals in the SignSSHOptions. &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 176cdfd3..2f357ebe 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -179,6 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { } func TestK8sSA_AuthorizeRenew(t *testing.T) { + now := time.Now() type test struct { p *K8sSA cert *x509.Certificate @@ -192,21 +193,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateK8sSA(nil) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -281,11 +288,11 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -313,7 +320,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p.Claims = &Claims{EnableSSHCA: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, @@ -365,11 +372,11 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertOptionsRequireValidator: assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshDefaultDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 72a275ff..11cff219 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -34,19 +34,18 @@ const ( // https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by // go.step.sm/crypto/x25519. type Nebula struct { - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - caPool *nebula.NebulaCAPool - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + caPool *nebula.NebulaCAPool + ctl *Controller } // Init verifies and initializes the Nebula provisioner. -func (p *Nebula) Init(config Config) error { +func (p *Nebula) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error { return errors.New("provisioner root(s) cannot be empty") } - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) if err != nil { return errs.InternalServer("failed to create ca pool: %v", err) } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // GetID returns the provisioner id. @@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) { // AuthorizeSign returns the list of SignOption for a Sign request. func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - crt, claims, err := p.authorizeToken(token, p.audiences.Sign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, err } @@ -154,7 +148,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // modifiers / withOptions newProvisionerExtensionOption(TypeNebula, p.Name, ""), profileLimitDuration{ - def: p.claimer.DefaultTLSCertDuration(), + def: p.ctl.Claimer.DefaultTLSCertDuration(), notBefore: crt.Details.NotBefore, notAfter: crt.Details.NotAfter, }, @@ -165,18 +159,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, IPs: crt.Details.Ips, }, defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // Currently the Nebula provisioner only grants host SSH certificates. func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, err } @@ -254,11 +248,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti return append(signOptions, templateOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, crt.Details.NotAfter}, + &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil @@ -266,7 +260,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti // AuthorizeRenew returns an error if the renewal is disabled. func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { + if p.ctl.Claimer.IsDisableRenewal() { return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName()) } return nil @@ -274,15 +268,15 @@ func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) erro // AuthorizeRevoke returns an error if the token is not valid. func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { - return p.validateToken(token, p.audiences.Revoke) + return p.validateToken(token, p.ctl.Audiences.Revoke) } // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { + if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil { return err } return nil diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index bc539af1..8f9afd9d 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) - t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) + t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) _, claims, err := parseToken(t1) if err != nil { t.Fatal(err) @@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) { ctx := context.TODO() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv) pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool @@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1"}, }, crt, priv) - okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) - okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv) + okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), }, crt, priv) - failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "user", }, crt, priv) - failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, @@ -584,12 +584,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -618,12 +618,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Provisioner with SSH disabled var bFalse bool @@ -657,7 +657,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -689,7 +689,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -726,20 +726,20 @@ func TestNebula_authorizeToken(t *testing.T) { t1 := now() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) - okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv) + okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, crt, priv) - okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) + okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv) // Token with errors - failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) - failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) + failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) - failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) // Not a nebula token jwk, err := generateJSONWebKey() @@ -761,7 +761,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.Sign[0]}, + Audience: []string{p.ctl.Audiences.Sign[0]}, } sshClaims := jose.Claims{ ID: "[REPLACEME]", @@ -770,7 +770,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.SSHSign[0]}, + Audience: []string{p.ctl.Audiences.SSHSign[0]}, } type args struct { @@ -785,14 +785,14 @@ func TestNebula_authorizeToken(t *testing.T) { want1 *jwtPayload wantErr bool }{ - {"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, SANs: []string{"10.1.0.1"}, }, false}, - {"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, }, false}, - {"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, Step: &stepPayload{ SSH: &SignSSHOptions{ @@ -802,16 +802,16 @@ func TestNebula_authorizeToken(t *testing.T) { }, }, }, false}, - {"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, }, false}, - {"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, - {"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, - {"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, - {"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, - {"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, - {"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, - {"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, + {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index ac1f2a25..1fc9bb4b 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -92,8 +92,7 @@ type OIDC struct { Options *Options `json:"options,omitempty"` configuration openIDConfiguration keyStore *keyStore - claimer *Claimer - getIdentityFunc GetIdentityFunc + ctl *Controller } func sanitizeEmail(email string) string { @@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) { } } - // Update claims with global ones - if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint u, err := url.Parse(o.ConfigurationEndpoint) if err != nil { @@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) { return err } - // Set the identity getter if it exists, otherwise use the default. - if config.GetIdentityFunc == nil { - o.getIdentityFunc = DefaultIdentityFunc - } else { - o.getIdentityFunc = config.GetIdentityFunc - } - return nil + o.ctl, err = NewController(o, o.Claims, config) + return } // ValidatePayload validates the given token payload. @@ -359,10 +348,10 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), - profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), + newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -371,15 +360,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if o.claimer.IsDisableRenewal() { - return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName()) - } - return nil + return o.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !o.claimer.IsSSHCAEnabled() { + if !o.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) } claims, err := o.authorizeToken(token) @@ -394,7 +380,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Get the identity using either the default identityFunc or one injected // externally. Note that the PreferredUsername might be empty. // TBD: Would preferred_username present a safety issue here? - iden, err := o.getIdentityFunc(ctx, o, claims.Email) + iden, err := o.ctl.GetIdentity(ctx, claims.Email) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } @@ -445,11 +431,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{o.claimer}, + &sshDefaultDuration{o.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{o.claimer}, + &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 7bf6ad7a..cfc789f9 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -332,11 +332,11 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") default: @@ -411,6 +411,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { } func TestOIDC_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() @@ -419,7 +420,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -432,8 +433,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -478,7 +485,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p6.Claims = &Claims{EnableSSHCA: &disable} - p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) + p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) assert.FatalError(t, err) // Update configuration endpoints and initialize @@ -494,10 +501,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } // Additional test needed for empty usernames and duplicate email and usernames @@ -527,8 +534,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 5d67762c..f4cffd78 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -11,28 +11,30 @@ import ( // SCEP provisioning flow type SCEP struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` ChallengePassword string `json:"challenge,omitempty"` Capabilities []string `json:"capabilities,omitempty"` + // IncludeRoot makes the provisioner return the CA root in addition to the // intermediate in the GetCACerts response IncludeRoot bool `json:"includeRoot,omitempty"` + // MinimumPublicKeyLength is the minimum length for public keys in CSRs MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - claimer *Claimer + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` secretChallengePassword string encryptionAlgorithm int + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (s *SCEP) DefaultTLSCertDuration() time.Duration { - return s.claimer.DefaultTLSCertDuration() + return s.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a SCEP type. @@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil { - return err - } - // Mask the actual challenge value, so it won't be marshaled s.secretChallengePassword = s.ChallengePassword s.ChallengePassword = "*** redacted ***" @@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? - return err + s.ctl, err = NewController(s, s.Claims, config) + return } // AuthorizeSign does not do any verification, because all verification is handled @@ -126,10 +124,10 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // modifiers / withOptions newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newForceCNOption(s.ForceCN), - profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), - newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), + newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), }, nil } diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index b59d6945..28a35639 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) - v := sshCertValidityValidator{p.claimer} + v := sshCertValidityValidator{p.ctl.Claimer} n := now() tests := []struct { name string @@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) { tests := map[string]func() test{ "fail/type-not-set": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), @@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/type-not-recognized": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ CertType: 4, ValidAfter: uint64(n.Unix()), @@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validAfter-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), @@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validBefore-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), @@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/no-limit": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/defaults": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/valid-requested-validBefore": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-after-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-before-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 3039d2a3..a7df38de 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -29,8 +29,7 @@ type SSHPOP struct { Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims,omitempty"` - claimer *Claimer - audiences Audiences + ctl *Controller sshPubKeys *SSHKeys } @@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a SSHPOP type. -func (p *SSHPOP) Init(config Config) error { +func (p *SSHPOP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error { return errors.New("provisioner public SSH validation keys cannot be empty") } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.sshPubKeys = config.SSHKeys - return nil + + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -186,7 +181,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // AuthorizeSSHRevoke validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { - claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } @@ -199,22 +194,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { // AuthorizeSSHRenew validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRenew) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } - - return claims.sshCert, nil - + return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert) } // AuthorizeSSHRekey validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRekey) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } @@ -225,7 +218,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, }, nil diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index da036864..715bf6de 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -38,6 +38,7 @@ func TestSSHPOP_Getters(t *testing.T) { } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -46,6 +47,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } @@ -455,7 +462,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index fe2678fc..ff8421f0 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -24,20 +24,22 @@ import ( ) var ( - defaultDisableRenewal = false - defaultEnableSSHCA = true - globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, + defaultDisableRenewal = false + defaultEnableRenewAfterExpiry = false + defaultEnableSSHCA = true + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, + DisableRenewal: &defaultDisableRenewal, + EnableRenewAfterExpiry: &defaultEnableRenewAfterExpiry, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, @@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &JWK{ + + p := &JWK{ Name: name, Type: "JWK", Key: &public, EncryptedKey: encrypted, Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { @@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } pubKeys := []interface{}{fooPub, barPub} if inputPubKey != nil { pubKeys = append(pubKeys, inputPubKey) } - return &K8sSA{ - Name: K8sSAName, - Type: "K8sSA", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - pubKeys: pubKeys, - }, nil + p := &K8sSA{ + Name: K8sSAName, + Type: "K8sSA", + Claims: &globalProvisionerClaims, + pubKeys: pubKeys, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateSSHPOP() (*SSHPOP, error) { @@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") if err != nil { return nil, err @@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) { return nil, err } - return &SSHPOP{ - Name: name, - Type: "SSHPOP", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, + p := &SSHPOP{ + Name: name, + Type: "SSHPOP", + Claims: &globalProvisionerClaims, sshPubKeys: &SSHKeys{ UserKeys: []ssh.PublicKey{userKey}, HostKeys: []ssh.PublicKey{hostKey}, }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateX5C(root []byte) (*X5C, error) { @@ -283,11 +279,6 @@ M46l92gdOozT if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - rootPool := x509.NewCertPool() var ( @@ -305,15 +296,17 @@ M46l92gdOozT } rootPool.AddCert(cert) } - return &X5C{ - Name: name, - Type: "X5C", - Roots: root, - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - rootPool: rootPool, - }, nil + p := &X5C{ + Name: name, + Type: "X5C", + Roots: root, + Claims: &globalProvisionerClaims, + rootPool: rootPool, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateOIDC() (*OIDC, error) { @@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &OIDC{ + p := &OIDC{ Name: name, Type: "OIDC", ClientID: clientID, @@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateGCP() (*GCP, error) { @@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &GCP{ + p := &GCP{ Type: "GCP", Name: name, ServiceAccounts: []string{serviceAccount}, Claims: &globalProvisionerClaims, - claimer: claimer, config: newGCPConfig(), keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - audiences: testAudiences.WithFragment("gcp/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("gcp/" + name), + }) + return p, err } func generateAWS() (*AWS, error) { @@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v2", "v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServer() (*AWS, *httptest.Server, error) { @@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { @@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } jwk, err := generateJSONWebKey() if err != nil { return nil, err } - return &Azure{ + p := &Azure{ Type: "Azure", Name: name, TenantID: tenantID, Audience: azureDefaultAudience, Claims: &globalProvisionerClaims, - claimer: claimer, config: newAzureConfig(tenantID), oidcConfig: openIDConfiguration{ Issuer: "https://sts.windows.net/" + tenantID + "/", @@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateAzureWithServer() (*Azure, *httptest.Server, error) { diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index aa44245d..6f534c76 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -26,15 +26,14 @@ type x5cPayload struct { // signature requests. type X5C struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences - rootPool *x509.CertPool + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + ctl *Controller + rootPool *x509.CertPool } // GetID returns the provisioner unique identifier. The name and credential id @@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a X5C type. -func (p *X5C) Init(config Config) error { +func (p *X5C) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -119,14 +118,9 @@ func (p *X5C) Init(config Config) error { return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -189,13 +183,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") } @@ -227,31 +221,30 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeX5C, p.Name, ""), - profileLimitDuration{p.claimer.DefaultTLSCertDuration(), - claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, + profileLimitDuration{ + p.ctl.Claimer.DefaultTLSCertDuration(), + claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter, + }, // validators commonNameValidator(claims.Subject), defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") } @@ -314,11 +307,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, + &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 2959f8c6..330e6e7a 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -2,6 +2,7 @@ package provisioner import ( "context" + "crypto/x509" "net/http" "testing" "time" @@ -69,7 +70,7 @@ func TestX5C_Init(t *testing.T) { }, "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, + p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, @@ -141,7 +142,7 @@ M46l92gdOozT } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) + assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) if tc.extraValid != nil { assert.Nil(t, tc.extraValid(tc.p)) } @@ -473,9 +474,9 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: - assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) + claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) case commonNameValidator: @@ -484,8 +485,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { case defaultSANsValidator: assert.Equals(t, []string(v), tc.sans) case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -551,6 +552,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { } func TestX5C_AuthorizeRenew(t *testing.T) { + now := time.Now() type test struct { p *X5C code int @@ -563,12 +565,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -582,7 +584,10 @@ func TestX5C_AuthorizeRenew(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { + if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -618,7 +623,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { // disable sshCA enable := false p.Claims = &Claims{EnableSSHCA: &enable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, @@ -774,10 +779,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { case sshCertDefaultsModifier: assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) case *sshLimitDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) diff --git a/authority/tls_test.go b/authority/tls_test.go index aeadaf0f..07538701 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -757,7 +757,7 @@ func TestAuthority_Renew(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -798,7 +798,7 @@ func TestAuthority_Renew(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -856,7 +856,7 @@ func TestAuthority_Renew(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), @@ -956,7 +956,7 @@ func TestAuthority_Rekey(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -998,7 +998,7 @@ func TestAuthority_Rekey(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -1063,7 +1063,7 @@ func TestAuthority_Rekey(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(),