diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 3e777e08..56a96490 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -386,6 +386,9 @@ func (p *AWS) readURL(url string) ([]byte, error) { default: return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v) } + if resp != nil { + resp.Body.Close() + } } // all versions have been exhausted and we haven't returned successfully yet so pass diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 50cdd32a..b5728de4 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -187,6 +187,31 @@ func TestAWS_GetIdentityToken(t *testing.T) { } } +func TestAWS_GetIdentityTokenV1Only(t *testing.T) { + aws, srv, err := generateAWSWithServerV1Only() + assert.FatalError(t, err) + defer srv.Close() + + subject := "foo.local" + caURL := "https://ca.smallstep.com" + u, err := url.Parse(caURL) + assert.Nil(t, err) + + token, err := aws.GetIdentityToken(subject, caURL) + assert.Nil(t, err) + + _, c, err := parseAWSToken(token) + if assert.NoError(t, err) { + assert.Equals(t, awsIssuer, c.Issuer) + assert.Equals(t, subject, c.Subject) + assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience) + assert.Equals(t, aws.Accounts[0], c.document.AccountID) + err = aws.config.certificate.CheckSignature( + aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature) + assert.NoError(t, err) + } +} + func TestAWS_Init(t *testing.T) { config := Config{ Claims: globalProvisionerClaims, @@ -203,6 +228,7 @@ func TestAWS_Init(t *testing.T) { DisableCustomSANs bool DisableTrustOnFirstUse bool InstanceAge Duration + IMDSVersions []string Claims *Claims } type args struct { @@ -214,12 +240,15 @@ func TestAWS_Init(t *testing.T) { args args wantErr bool }{ - {"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, nil}, args{config}, false}, - {"ok", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, nil}, args{config}, false}, - {"fail type ", fields{"", "name", []string{"account"}, false, false, zero, nil}, args{config}, true}, - {"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, nil}, args{config}, true}, - {"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, nil}, args{config}, true}, - {"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, badClaims}, args{config}, true}, + {"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, false}, + {"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, nil}, args{config}, false}, + {"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, nil}, args{config}, false}, + {"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, false}, + {"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true}, + {"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true}, + {"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, true}, + {"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, nil}, args{config}, true}, + {"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, badClaims}, args{config}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -230,6 +259,7 @@ func TestAWS_Init(t *testing.T) { DisableCustomSANs: tt.fields.DisableCustomSANs, DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, InstanceAge: tt.fields.InstanceAge, + IMDSVersions: tt.fields.IMDSVersions, Claims: tt.fields.Claims, } if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 84283631..52fb470a 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -495,6 +495,101 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) { return aws, srv, nil } +func generateAWSV1Only() (*AWS, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + accountID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + claimer, err := NewClaimer(nil, globalProvisionerClaims) + if err != nil { + return nil, err + } + block, _ := pem.Decode([]byte(awsTestCertificate)) + if block == nil || block.Type != "CERTIFICATE" { + return nil, errors.New("error decoding AWS certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrap(err, "error parsing AWS certificate") + } + return &AWS{ + Type: "AWS", + Name: name, + Accounts: []string{accountID}, + Claims: &globalProvisionerClaims, + IMDSVersions: []string{"v1"}, + claimer: claimer, + config: &awsConfig{ + identityURL: awsIdentityURL, + signatureURL: awsSignatureURL, + tokenURL: awsAPITokenURL, + tokenTTL: awsAPITokenTTL, + certificate: cert, + signatureAlgorithm: awsSignatureAlgorithm, + }, + audiences: testAudiences.WithFragment("aws/" + name), + }, nil +} + +func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { + aws, err := generateAWSV1Only() + if err != nil { + return nil, nil, err + } + block, _ := pem.Decode([]byte(awsTestKey)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, nil, errors.New("error decoding AWS key") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(err, "error parsing AWS private key") + } + doc, err := json.MarshalIndent(awsInstanceIdentityDocument{ + AccountID: aws.Accounts[0], + Architecture: "x86_64", + AvailabilityZone: "us-west-2b", + ImageID: "image-id", + InstanceID: "instance-id", + InstanceType: "t2.micro", + PendingTime: time.Now(), + PrivateIP: "127.0.0.1", + Region: "us-west-1", + Version: "2017-09-30", + }, "", " ") + if err != nil { + return nil, nil, err + } + + sum := sha256.Sum256(doc) + signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256) + if err != nil { + return nil, nil, errors.Wrap(err, "error signing document") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/dynamic/instance-identity/document": + w.Write(doc) + case "/latest/dynamic/instance-identity/signature": + w.Write([]byte(base64.StdEncoding.EncodeToString(signature))) + case "/bad-document": + w.Write([]byte("{}")) + case "/bad-signature": + w.Write([]byte("YmFkLXNpZ25hdHVyZQo=")) + case "/bad-json": + w.Write([]byte("{")) + default: + http.NotFound(w, r) + } + })) + aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document" + aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature" + return aws, srv, nil +} + func generateAzure() (*Azure, error) { name, err := randutil.Alphanumeric(10) if err != nil {