diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index e6ac3359..717aa6e7 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -139,9 +139,7 @@ func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) { // generates a token with them. func (p *Azure) GetIdentityToken() (string, error) { // Initialize the config if this method is used from the cli. - if err := p.assertConfig(); err != nil { - return "", err - } + p.assertConfig() req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody) if err != nil { @@ -183,9 +181,7 @@ func (p *Azure) Init(config Config) (err error) { p.Audience = azureDefaultAudience } // Initialize config - if err := p.assertConfig(); err != nil { - return err - } + p.assertConfig() // Update claims with global ones if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { @@ -296,10 +292,8 @@ func (p *Azure) AuthorizeRevoke(token string) error { } // assertConfig initializes the config if it has not been initialized -func (p *Azure) assertConfig() error { - if p.config != nil { - return nil +func (p *Azure) assertConfig() { + if p.config == nil { + p.config = newAzureConfig(p.TenantID) } - p.config = newAzureConfig(p.TenantID) - return nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index c986b5ce..5b3c817c 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -1,9 +1,15 @@ package provisioner import ( + "crypto/sha256" "crypto/x509" - "reflect" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/smallstep/assert" ) @@ -28,39 +34,44 @@ func TestAzure_Getters(t *testing.T) { } func TestAzure_GetTokenID(t *testing.T) { - type fields struct { - Type string - Name string - DisableCustomSANs bool - DisableTrustOnFirstUse bool - Claims *Claims - claimer *Claimer - config *azureConfig - } + p1, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + + p2, err := generateAzure() + assert.FatalError(t, err) + p2.TenantID = p1.TenantID + p2.config = p1.config + p2.oidcConfig = p1.oidcConfig + p2.keyStore = p1.keyStore + p2.DisableTrustOnFirstUse = true + + t1, err := p1.GetIdentityToken() + assert.FatalError(t, err) + t2, err := p2.GetIdentityToken() + assert.FatalError(t, err) + + sum := sha256.Sum256([]byte("/subscriptions/subscriptionID/resourceGroups/resourceGroup/providers/Microsoft.Compute/virtualMachines/virtualMachine")) + w1 := strings.ToLower(hex.EncodeToString(sum[:])) + type args struct { token string } tests := []struct { name string - fields fields + azure *Azure args args want string wantErr bool }{ - // TODO: Add test cases. + {"ok", p1, args{t1}, w1, false}, + {"ok no TOFU", p2, args{t2}, "the-jti", false}, + {"fail token", p1, args{"bad-token"}, "", true}, + {"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := &Azure{ - Type: tt.fields.Type, - Name: tt.fields.Name, - DisableCustomSANs: tt.fields.DisableCustomSANs, - DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, - Claims: tt.fields.Claims, - claimer: tt.fields.claimer, - config: tt.fields.config, - } - got, err := p.GetTokenID(tt.args.token) + got, err := tt.azure.GetTokenID(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) return @@ -72,8 +83,58 @@ func TestAzure_GetTokenID(t *testing.T) { } } +func TestAzure_GetIdentityToken(t *testing.T) { + p1, err := generateAzure() + assert.FatalError(t, err) + + t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/bad-request": + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + case "/bad-json": + w.Write([]byte(t1)) + default: + w.Header().Add("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{"access_token":"%s"}`, t1))) + } + })) + defer srv.Close() + + tests := []struct { + name string + azure *Azure + identityTokenURL string + want string + wantErr bool + }{ + {"ok", p1, srv.URL, t1, false}, + {"fail request", p1, srv.URL + "/bad-request", "", true}, + {"fail unmarshal", p1, srv.URL + "/bad-json", "", true}, + {"fail url", p1, "://ca.smallstep.com", "", true}, + {"fail connect", p1, "foobarzar", "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.azure.config.identityTokenURL = tt.identityTokenURL + got, err := tt.azure.GetIdentityToken() + if (err != nil) != tt.wantErr { + t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Azure.GetIdentityToken() = %v, want %v", got, tt.want) + } + }) + } +} + func TestAzure_Init(t *testing.T) { - az, srv, err := generateAzureWithServer() + p1, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() @@ -84,13 +145,25 @@ func TestAzure_Init(t *testing.T) { DefaultTLSDur: &Duration{0}, } + badDiscoveryURL := &azureConfig{ + oidcDiscoveryURL: srv.URL + "/error", + identityTokenURL: p1.config.identityTokenURL, + } + badJWKURL := &azureConfig{ + oidcDiscoveryURL: srv.URL + "/openid-configuration-fail-jwk", + identityTokenURL: p1.config.identityTokenURL, + } + badAzureConfig := &azureConfig{ + oidcDiscoveryURL: srv.URL + "/openid-configuration-no-issuer", + identityTokenURL: p1.config.identityTokenURL, + } + type fields struct { - Type string - Name string - TenantID string - DisableCustomSANs bool - DisableTrustOnFirstUse bool - Claims *Claims + Type string + Name string + TenantID string + Claims *Claims + config *azureConfig } type args struct { config Config @@ -101,25 +174,24 @@ func TestAzure_Init(t *testing.T) { args args wantErr bool }{ - {"ok", fields{az.Type, az.Name, az.TenantID, false, false, nil}, args{config}, false}, - {"ok", fields{az.Type, az.Name, az.TenantID, true, false, nil}, args{config}, false}, - {"ok", fields{az.Type, az.Name, az.TenantID, false, true, nil}, args{config}, false}, - {"ok", fields{az.Type, az.Name, az.TenantID, true, true, nil}, args{config}, false}, - {"fail type", fields{"", az.Name, az.TenantID, false, false, nil}, args{config}, true}, - {"fail name", fields{az.Type, "", az.TenantID, false, false, nil}, args{config}, true}, - {"fail tenant id", fields{az.Type, az.Name, "", false, false, nil}, args{config}, true}, - {"fail claims", fields{az.Type, az.Name, az.TenantID, false, false, badClaims}, args{config}, true}, + {"ok", fields{p1.Type, p1.Name, p1.TenantID, nil, p1.config}, args{config}, false}, + {"ok with config", fields{p1.Type, p1.Name, p1.TenantID, nil, p1.config}, args{config}, false}, + {"fail type", fields{"", p1.Name, p1.TenantID, nil, p1.config}, args{config}, true}, + {"fail name", fields{p1.Type, "", p1.TenantID, nil, p1.config}, args{config}, true}, + {"fail tenant id", fields{p1.Type, p1.Name, "", nil, p1.config}, args{config}, true}, + {"fail claims", fields{p1.Type, p1.Name, p1.TenantID, badClaims, p1.config}, args{config}, true}, + {"fail discovery URL", fields{p1.Type, p1.Name, p1.TenantID, nil, badDiscoveryURL}, args{config}, true}, + {"fail JWK URL", fields{p1.Type, p1.Name, p1.TenantID, nil, badJWKURL}, args{config}, true}, + {"fail config Validate", fields{p1.Type, p1.Name, p1.TenantID, nil, badAzureConfig}, args{config}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Azure{ - Type: tt.fields.Type, - Name: tt.fields.Name, - TenantID: tt.fields.TenantID, - DisableCustomSANs: tt.fields.DisableCustomSANs, - DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, - Claims: tt.fields.Claims, - config: az.config, + Type: tt.fields.Type, + Name: tt.fields.Name, + TenantID: tt.fields.TenantID, + Claims: tt.fields.Claims, + config: tt.fields.config, } if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr) @@ -129,46 +201,101 @@ func TestAzure_Init(t *testing.T) { } func TestAzure_AuthorizeSign(t *testing.T) { - type fields struct { - Type string - Name string - DisableCustomSANs bool - DisableTrustOnFirstUse bool - Claims *Claims - claimer *Claimer - config *azureConfig - } + p1, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + + p2, err := generateAzure() + assert.FatalError(t, err) + p2.TenantID = p1.TenantID + p2.Subscriptions = []string{"subscriptionID"} + p2.config = p1.config + p2.oidcConfig = p1.oidcConfig + p2.keyStore = p1.keyStore + p2.DisableCustomSANs = true + + p3, err := generateAzure() + assert.FatalError(t, err) + p3.config = p1.config + p3.oidcConfig = p1.oidcConfig + p3.keyStore = p1.keyStore + + p4, err := generateAzure() + assert.FatalError(t, err) + p4.TenantID = p1.TenantID + p4.Subscriptions = []string{"subscriptionID1"} + p4.config = p1.config + p4.oidcConfig = p1.oidcConfig + p4.keyStore = p1.keyStore + + badKey, err := generateJSONWebKey() + assert.FatalError(t, err) + + t1, err := p1.GetIdentityToken() + assert.FatalError(t, err) + t2, err := p2.GetIdentityToken() + assert.FatalError(t, err) + t3, err := p3.GetIdentityToken() + assert.FatalError(t, err) + t4, err := p4.GetIdentityToken() + assert.FatalError(t, err) + + t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + + failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), badKey) + assert.FatalError(t, err) + type args struct { token string } tests := []struct { name string - fields fields + azure *Azure args args - want []SignOption + wantLen int wantErr bool }{ - // TODO: Add test cases. + {"ok", p1, args{t1}, 4, false}, + {"ok", p2, args{t2}, 5, false}, + {"ok", p1, args{t11}, 4, false}, + {"fail tenant", p3, args{t3}, 0, true}, + {"fail subscription", p4, args{t4}, 0, true}, + {"fail token", p1, args{"token"}, 0, true}, + {"fail issuer", p1, args{failIssuer}, 0, true}, + {"fail audience", p1, args{failAudience}, 0, true}, + {"fail exp", p1, args{failExp}, 0, true}, + {"fail nbf", p1, args{failNbf}, 0, true}, + {"fail key", p1, args{failKey}, 0, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := &Azure{ - Type: tt.fields.Type, - Name: tt.fields.Name, - DisableCustomSANs: tt.fields.DisableCustomSANs, - DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, - Claims: tt.fields.Claims, - claimer: tt.fields.claimer, - config: tt.fields.config, - } - got, err := p.AuthorizeSign(tt.args.token) + got, err := tt.azure.AuthorizeSign(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Azure.AuthorizeSign() = %v, want %v", got, tt.want) - } + assert.Len(t, tt.wantLen, got) }) } } @@ -207,40 +334,51 @@ func TestAzure_AuthorizeRenewal(t *testing.T) { } func TestAzure_AuthorizeRevoke(t *testing.T) { - type fields struct { - Type string - Name string - DisableCustomSANs bool - DisableTrustOnFirstUse bool - Claims *Claims - claimer *Claimer - config *azureConfig - } + az, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + + token, err := az.GetIdentityToken() + assert.FatalError(t, err) + type args struct { token string } tests := []struct { name string - fields fields + azure *Azure args args wantErr bool }{ - // TODO: Add test cases. + {"ok token", az, args{token}, true}, // revoke is disabled + {"bad token", az, args{"bad token"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := &Azure{ - Type: tt.fields.Type, - Name: tt.fields.Name, - DisableCustomSANs: tt.fields.DisableCustomSANs, - DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, - Claims: tt.fields.Claims, - claimer: tt.fields.claimer, - config: tt.fields.config, - } - if err := p.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { + if err := tt.azure.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) } }) } } + +func TestAzure_assertConfig(t *testing.T) { + p1, err := generateAzure() + assert.FatalError(t, err) + p2, err := generateAzure() + assert.FatalError(t, err) + p2.config = nil + + tests := []struct { + name string + azure *Azure + }{ + {"ok with config", p1}, + {"ok no config", p2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.azure.assertConfig() + }) + } +} diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 94fc7015..23175677 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -350,6 +350,7 @@ func generateAzure() (*Azure, error) { Type: "Azure", Name: name, TenantID: tenantID, + Audience: azureDefaultAudience, Claims: &globalProvisionerClaims, claimer: claimer, config: newAzureConfig(tenantID), @@ -394,6 +395,10 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) case "/" + az.TenantID + "/.well-known/openid-configuration": writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"}) + case "/openid-configuration-no-issuer": + writeJSON(w, openIDConfiguration{Issuer: "", JWKSetURI: srv.URL + "/jwks_uri"}) + case "/openid-configuration-fail-jwk": + writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/error"}) case "/random": keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) w.Header().Add("Cache-Control", "max-age=5") @@ -579,6 +584,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, + ID: "the-jti", }, AppID: "the-appid", AppIDAcr: "the-appidacr",