Refactor cloudkms signer to return an error on the constructor.

pull/488/head
Mariano Cano 4 years ago
parent cae08bff80
commit 163eb7029c

@ -140,19 +140,7 @@ func (k *CloudKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer,
if req.SigningKey == "" { if req.SigningKey == "" {
return nil, errors.New("signing key cannot be empty") return nil, errors.New("signing key cannot be empty")
} }
return NewSigner(k.client, req.SigningKey)
// Validate that the key exists
ctx, cancel := defaultContext()
defer cancel()
_, err := k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: req.SigningKey,
})
if err != nil {
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
}
return NewSigner(k.client, req.SigningKey), nil
} }
// CreateKey creates in Google's Cloud KMS a new asymmetric key for signing. // CreateKey creates in Google's Cloud KMS a new asymmetric key for signing.

@ -165,6 +165,15 @@ func TestCloudKMS_Close(t *testing.T) {
func TestCloudKMS_CreateSigner(t *testing.T) { func TestCloudKMS_CreateSigner(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
pemBytes, err := ioutil.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
type fields struct { type fields struct {
client KeyManagementClient client KeyManagementClient
} }
@ -178,8 +187,16 @@ func TestCloudKMS_CreateSigner(t *testing.T) {
want crypto.Signer want crypto.Signer
wantErr bool wantErr bool
}{ }{
{"ok", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName}, false}, {"ok", fields{&MockClient{
{"fail", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true}, getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false},
{"fail", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, fmt.Errorf("test error")
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -191,6 +208,9 @@ func TestCloudKMS_CreateSigner(t *testing.T) {
t.Errorf("CloudKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("CloudKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if signer, ok := got.(*Signer); ok {
signer.client = &MockClient{}
}
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CloudKMS.CreateSigner() = %v, want %v", got, tt.want) t.Errorf("CloudKMS.CreateSigner() = %v, want %v", got, tt.want)
} }

@ -13,33 +13,41 @@ import (
type Signer struct { type Signer struct {
client KeyManagementClient client KeyManagementClient
signingKey string signingKey string
publicKey crypto.PublicKey
} }
func NewSigner(c KeyManagementClient, signingKey string) *Signer { // NewSigner creates a new crypto.Signer the given CloudKMS signing key.
return &Signer{ func NewSigner(c KeyManagementClient, signingKey string) (*Signer, error) {
// Make sure that the key exists.
signer := &Signer{
client: c, client: c,
signingKey: signingKey, signingKey: signingKey,
} }
if err := signer.preloadKey(signingKey); err != nil {
return nil, err
}
return signer, nil
} }
// Public returns the public key of this signer or an error. func (s *Signer) preloadKey(signingKey string) error {
func (s *Signer) Public() crypto.PublicKey {
ctx, cancel := defaultContext() ctx, cancel := defaultContext()
defer cancel() defer cancel()
response, err := s.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{ response, err := s.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: s.signingKey, Name: signingKey,
}) })
if err != nil { if err != nil {
return errors.Wrap(err, "cloudKMS GetPublicKey failed") return errors.Wrap(err, "cloudKMS GetPublicKey failed")
} }
pk, err := pemutil.ParseKey([]byte(response.Pem)) s.publicKey, err = pemutil.ParseKey([]byte(response.Pem))
if err != nil {
return err return err
} }
return pk // Public returns the public key of this signer or an error.
func (s *Signer) Public() crypto.PublicKey {
return s.publicKey
} }
// Sign signs digest with the private key stored in Google's Cloud KMS. // Sign signs digest with the private key stored in Google's Cloud KMS.

@ -16,6 +16,15 @@ import (
) )
func Test_newSigner(t *testing.T) { func Test_newSigner(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
type args struct { type args struct {
c KeyManagementClient c KeyManagementClient
signingKey string signingKey string
@ -24,22 +33,42 @@ func Test_newSigner(t *testing.T) {
name string name string
args args args args
want *Signer want *Signer
wantErr bool
}{ }{
{"ok", args{&MockClient{}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey"}}, {"ok", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey", publicKey: pk}, false},
{"fail get public key", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, fmt.Errorf("an error")
},
}, "signingKey"}, nil, true},
{"fail parse pem", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
},
}, "signingKey"}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := NewSigner(tt.args.c, tt.args.signingKey); !reflect.DeepEqual(got, tt.want) { got, err := NewSigner(tt.args.c, tt.args.signingKey)
t.Errorf("newSigner() = %v, want %v", got, tt.want) if (err != nil) != tt.wantErr {
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
got.client = &MockClient{}
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
} }
}) })
} }
} }
func Test_signer_Public(t *testing.T) { func Test_signer_Public(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
testError := fmt.Errorf("an error")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := ioutil.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -52,42 +81,23 @@ func Test_signer_Public(t *testing.T) {
type fields struct { type fields struct {
client KeyManagementClient client KeyManagementClient
signingKey string signingKey string
publicKey crypto.PublicKey
} }
tests := []struct { tests := []struct {
name string name string
fields fields fields fields
want crypto.PublicKey want crypto.PublicKey
wantErr bool
}{ }{
{"ok", fields{&MockClient{ {"ok", fields{&MockClient{}, "signingKey", pk}, pk},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}, keyName}, pk, false},
{"fail get public key", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return nil, testError
},
}, keyName}, nil, true},
{"fail parse pem", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
},
}, keyName}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &Signer{ s := &Signer{
client: tt.fields.client, client: tt.fields.client,
signingKey: tt.fields.signingKey, signingKey: tt.fields.signingKey,
publicKey: tt.fields.publicKey,
} }
got := s.Public() if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
if _, ok := got.(error); ok != tt.wantErr {
t.Errorf("signer.Public() error = %v, wantErr %v", got, tt.wantErr)
return
}
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
t.Errorf("signer.Public() = %v, want %v", got, tt.want) t.Errorf("signer.Public() = %v, want %v", got, tt.want)
} }
}) })

Loading…
Cancel
Save