package awskms

import (
	"context"
	"crypto"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kms"
	"github.com/smallstep/certificates/kms/apiv1"
	"github.com/smallstep/cli/crypto/pemutil"
)

// TestNew checks that New returns a KMS with both a session and a service
// client populated, and that it returns an error when the AWS session cannot
// be created.
func TestNew(t *testing.T) {
	ctx := context.Background()

	sess, err := session.NewSessionWithOptions(session.Options{})
	if err != nil {
		t.Fatal(err)
	}
	expected := &KMS{
		session: sess,
		service: kms.New(sess),
	}

	// forceError makes session creation fail by pointing AWS_CA_BUNDLE at a
	// file that does not exist; missing credentials alone do not cause an
	// error. The variable's previous state is restored when t finishes.
	forceError := func(t *testing.T) {
		key := "AWS_CA_BUNDLE"
		value, wasSet := os.LookupEnv(key)
		if err := os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt")); err != nil {
			t.Fatal(err)
		}
		t.Cleanup(func() {
			// Distinguish "unset" from "set to an empty string" so the
			// environment is restored exactly as it was found.
			if wasSet {
				os.Setenv(key, value)
			} else {
				os.Unsetenv(key)
			}
		})
	}

	type args struct {
		ctx  context.Context
		opts apiv1.Options
	}
	tests := []struct {
		name    string
		args    args
		want    *KMS
		wantErr bool
	}{
		{"ok", args{ctx, apiv1.Options{}}, expected, false},
		{"ok with options", args{ctx, apiv1.Options{
			Region:          "us-east-1",
			Profile:         "smallstep",
			CredentialsFile: "~/aws/credentials",
		}}, expected, false},
		{"fail", args{ctx, apiv1.Options{}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Force an error in the session loading.
			if tt.wantErr {
				forceError(t)
			}

			got, err := New(tt.args.ctx, tt.args.opts)
			if (err != nil) != tt.wantErr {
				t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				if !reflect.DeepEqual(got, tt.want) {
					t.Errorf("New() = %#v, want %#v", got, tt.want)
				}
				return
			}
			// On success only check that both fields were populated rather
			// than deep-comparing the sessions.
			if got == nil || got.session == nil || got.service == nil {
				t.Errorf("New() = %#v, want %#v", got, tt.want)
			}
		})
	}
}

// TestKMS_GetPublicKey checks that GetPublicKey resolves the key from an
// "awskms:" URI, returns the parsed public key from the client, and fails on an
// empty or key-less name, on a client error, and on a non-DER payload.
func TestKMS_GetPublicKey(t *testing.T) {
	okClient := getOKClient()
	key, err := pemutil.ParseKey([]byte(publicKey))
	if err != nil {
		t.Fatal(err)
	}

	type fields struct {
		session *session.Session
		service KeyManagementClient
	}
	type args struct {
		req *apiv1.GetPublicKeyRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    crypto.PublicKey
		wantErr bool
	}{
		{"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
			Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
		}}, key, false},
		{"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
		{"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
			Name: "awskms:key-id=",
		}}, nil, true},
		{"fail getPublicKey", fields{nil, &MockClient{
			getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
				return nil, fmt.Errorf("an error")
			},
		}}, args{&apiv1.GetPublicKeyRequest{
			Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
		}}, nil, true},
		// The client returns the PEM text instead of DER bytes.
		{"fail not der", fields{nil, &MockClient{
			getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
				return &kms.GetPublicKeyOutput{
					KeyId:     input.KeyId,
					PublicKey: []byte(publicKey),
				}, nil
			},
		}}, args{&apiv1.GetPublicKeyRequest{
			Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			k := &KMS{
				session: tt.fields.session,
				service: tt.fields.service,
			}
			got, err := k.GetPublicKey(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestKMS_CreateKey checks CreateKey for EC and RSA keys, the rejection of an
// empty request, an unsupported algorithm and an unsupported RSA size, and the
// propagation of errors from each of the three client calls it depends on.
func TestKMS_CreateKey(t *testing.T) {
	okClient := getOKClient()
	key, err := pemutil.ParseKey([]byte(publicKey))
	if err != nil {
		t.Fatal(err)
	}

	type fields struct {
		session *session.Session
		service KeyManagementClient
	}
	type args struct {
		req *apiv1.CreateKeyRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *apiv1.CreateKeyResponse
		wantErr bool
	}{
		{"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.ECDSAWithSHA256,
		}}, &apiv1.CreateKeyResponse{
			Name:      "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
			PublicKey: key,
			CreateSignerRequest: apiv1.CreateSignerRequest{
				SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
			},
		}, false},
		{"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.SHA256WithRSA,
			Bits:               2048,
		}}, &apiv1.CreateKeyResponse{
			Name:      "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
			PublicKey: key,
			CreateSignerRequest: apiv1.CreateSignerRequest{
				SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
			},
		}, false},
		{"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true},
		{"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.PureEd25519,
		}}, nil, true},
		{"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.SHA256WithRSA,
			Bits:               1234,
		}}, nil, true},
		{"fail createKey", fields{nil, &MockClient{
			createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
				return nil, fmt.Errorf("an error")
			},
			createAliasWithContext:  okClient.createAliasWithContext,
			getPublicKeyWithContext: okClient.getPublicKeyWithContext,
		}}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.ECDSAWithSHA256,
		}}, nil, true},
		{"fail createAlias", fields{nil, &MockClient{
			createKeyWithContext: okClient.createKeyWithContext,
			createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
				return nil, fmt.Errorf("an error")
			},
			getPublicKeyWithContext: okClient.getPublicKeyWithContext,
		}}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.ECDSAWithSHA256,
		}}, nil, true},
		{"fail getPublicKey", fields{nil, &MockClient{
			createKeyWithContext:   okClient.createKeyWithContext,
			createAliasWithContext: okClient.createAliasWithContext,
			getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
				return nil, fmt.Errorf("an error")
			},
		}}, args{&apiv1.CreateKeyRequest{
			Name:               "root",
			SignatureAlgorithm: apiv1.ECDSAWithSHA256,
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			k := &KMS{
				session: tt.fields.session,
				service: tt.fields.service,
			}
			got, err := k.CreateKey(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestKMS_CreateSigner checks that CreateSigner returns a Signer for a valid
// key URI with its public key already loaded. It also checks that it fails for
// an empty signing key, and when the public key cannot be fetched from the
// client.
func TestKMS_CreateSigner(t *testing.T) {
	client := getOKClient()
	key, err := pemutil.ParseKey([]byte(publicKey))
	if err != nil {
		t.Fatal(err)
	}

	// failClient rejects every GetPublicKey call, so preloading the signer's
	// public key fails even though the key URI itself is valid.
	failClient := &MockClient{
		getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
			return nil, fmt.Errorf("an error")
		},
	}

	type fields struct {
		session *session.Session
		service KeyManagementClient
	}
	type args struct {
		req *apiv1.CreateSignerRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    crypto.Signer
		wantErr bool
	}{
		{"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{
			SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
		}}, &Signer{
			service:   client,
			keyID:     "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
			publicKey: key,
		}, false},
		{"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
		{"fail preload", fields{nil, failClient}, args{&apiv1.CreateSignerRequest{
			SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			k := &KMS{
				session: tt.fields.session,
				service: tt.fields.service,
			}
			got, err := k.CreateSigner(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// The signer returned alongside an error is meaningless (it
				// may be a typed nil wrapped in the crypto.Signer interface,
				// which reflect.DeepEqual does not treat as nil), so it is not
				// compared.
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestKMS_Close checks that Close succeeds on a KMS backed by the OK client.
func TestKMS_Close(t *testing.T) {
	type fields struct {
		session *session.Session
		service KeyManagementClient
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{"ok", fields{nil, getOKClient()}, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			k := &KMS{
				session: tt.fields.session,
				service: tt.fields.service,
			}
			if err := k.Close(); (err != nil) != tt.wantErr {
				t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// Test_parseKeyID checks that parseKeyID extracts the key id from an
// "awskms:key-id=" URI, passes plain key ids and ARNs through unchanged, and
// rejects malformed escapes, empty key ids and URIs without a key-id.
func Test_parseKeyID(t *testing.T) {
	type args struct {
		name string
	}
	tests := []struct {
		name    string
		args    args
		want    string
		wantErr bool
	}{
		{"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
		{"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
		{"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
		{"fail parse", args{"awskms:key-id=%ZZ"}, "", true},
		{"fail empty key", args{"awskms:key-id="}, "", true},
		{"fail missing", args{"awskms:foo=bar"}, "", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := parseKeyID(tt.args.name)
			if (err != nil) != tt.wantErr {
				t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("parseKeyID() = %v, want %v", got, tt.want)
			}
		})
	}
}