Merge branch 'master' into hs/acme-eab
commit
bcd1240a0e
@ -0,0 +1,80 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/smallstep/certificates/kms/azurekms (interfaces: KeyVaultClient)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// KeyVaultClient is a mock of KeyVaultClient interface
|
||||
type KeyVaultClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *KeyVaultClientMockRecorder
|
||||
}
|
||||
|
||||
// KeyVaultClientMockRecorder is the mock recorder for KeyVaultClient
|
||||
type KeyVaultClientMockRecorder struct {
|
||||
mock *KeyVaultClient
|
||||
}
|
||||
|
||||
// NewKeyVaultClient creates a new mock instance
|
||||
func NewKeyVaultClient(ctrl *gomock.Controller) *KeyVaultClient {
|
||||
mock := &KeyVaultClient{ctrl: ctrl}
|
||||
mock.recorder = &KeyVaultClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *KeyVaultClient) EXPECT() *KeyVaultClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateKey mocks base method
|
||||
func (m *KeyVaultClient) CreateKey(arg0 context.Context, arg1, arg2 string, arg3 keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateKey", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(keyvault.KeyBundle)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateKey indicates an expected call of CreateKey
|
||||
func (mr *KeyVaultClientMockRecorder) CreateKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*KeyVaultClient)(nil).CreateKey), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// GetKey mocks base method
|
||||
func (m *KeyVaultClient) GetKey(arg0 context.Context, arg1, arg2, arg3 string) (keyvault.KeyBundle, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKey", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(keyvault.KeyBundle)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKey indicates an expected call of GetKey
|
||||
func (mr *KeyVaultClientMockRecorder) GetKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*KeyVaultClient)(nil).GetKey), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// Sign mocks base method
|
||||
func (m *KeyVaultClient) Sign(arg0 context.Context, arg1, arg2, arg3 string, arg4 keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(keyvault.KeyOperationResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Sign indicates an expected call of Sign
|
||||
func (mr *KeyVaultClientMockRecorder) Sign(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*KeyVaultClient)(nil).Sign), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
@ -0,0 +1,342 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/uri"
|
||||
)
|
||||
|
||||
func init() {
|
||||
apiv1.Register(apiv1.AzureKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) {
|
||||
return New(ctx, opts)
|
||||
})
|
||||
}
|
||||
|
||||
// Scheme is the scheme used for the Azure Key Vault uris.
|
||||
const Scheme = "azurekms"
|
||||
|
||||
// keyIDRegexp is the regular expression that Key Vault uses on the kid. We can
|
||||
// extract the vault, name and version of the key.
|
||||
var keyIDRegexp = regexp.MustCompile(`^https://([0-9a-zA-Z-]+)\.vault\.azure\.net/keys/([0-9a-zA-Z-]+)/([0-9a-zA-Z-]+)$`)
|
||||
|
||||
var (
|
||||
valueTrue = true
|
||||
value2048 int32 = 2048
|
||||
value3072 int32 = 3072
|
||||
value4096 int32 = 4096
|
||||
)
|
||||
|
||||
var now = func() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
type keyType struct {
|
||||
Kty keyvault.JSONWebKeyType
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
}
|
||||
|
||||
func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType {
|
||||
switch k.Kty {
|
||||
case keyvault.EC:
|
||||
if pl == apiv1.HSM {
|
||||
return keyvault.ECHSM
|
||||
}
|
||||
return k.Kty
|
||||
case keyvault.RSA:
|
||||
if pl == apiv1.HSM {
|
||||
return keyvault.RSAHSM
|
||||
}
|
||||
return k.Kty
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]keyType{
|
||||
apiv1.UnspecifiedSignAlgorithm: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P256,
|
||||
},
|
||||
apiv1.SHA256WithRSA: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA384WithRSA: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA512WithRSA: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA256WithRSAPSS: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA384WithRSAPSS: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.SHA512WithRSAPSS: {
|
||||
Kty: keyvault.RSA,
|
||||
},
|
||||
apiv1.ECDSAWithSHA256: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P256,
|
||||
},
|
||||
apiv1.ECDSAWithSHA384: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P384,
|
||||
},
|
||||
apiv1.ECDSAWithSHA512: {
|
||||
Kty: keyvault.EC,
|
||||
Curve: keyvault.P521,
|
||||
},
|
||||
}
|
||||
|
||||
// vaultResource is the value the client will use as audience.
|
||||
const vaultResource = "https://vault.azure.net"
|
||||
|
||||
// KeyVaultClient is the interface implemented by keyvault.BaseClient. It will
|
||||
// be used for testing purposes.
|
||||
type KeyVaultClient interface {
|
||||
GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (keyvault.KeyBundle, error)
|
||||
CreateKey(ctx context.Context, vaultBaseURL string, keyName string, parameters keyvault.KeyCreateParameters) (keyvault.KeyBundle, error)
|
||||
Sign(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string, parameters keyvault.KeySignParameters) (keyvault.KeyOperationResult, error)
|
||||
}
|
||||
|
||||
// KeyVault implements a KMS using Azure Key Vault.
|
||||
//
|
||||
// The URI format used in Azure Key Vault is the following:
|
||||
//
|
||||
// - azurekms:name=key-name;vault=vault-name
|
||||
// - azurekms:name=key-name;vault=vault-name?version=key-version
|
||||
// - azurekms:name=key-name;vault=vault-name?hsm=true
|
||||
//
|
||||
// The scheme is "azurekms"; "name" is the key name; "vault" is the key vault
|
||||
// name where the key is located; "version" is an optional parameter that
|
||||
// defines the version of they key, if version is not given, the latest one will
|
||||
// be used; "hsm" defines if an HSM want to be used for this key, this is
|
||||
// specially useful when this is used from `step`.
|
||||
//
|
||||
// TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault
|
||||
// package, at some point Azure might create a keyvault client with all the
|
||||
// functionality in /sdk/keyvault, we should migrate to that once available.
|
||||
type KeyVault struct {
|
||||
baseClient KeyVaultClient
|
||||
defaults DefaultOptions
|
||||
}
|
||||
|
||||
// DefaultOptions are custom options that can be passed as defaults using the
|
||||
// URI in apiv1.Options.
|
||||
type DefaultOptions struct {
|
||||
Vault string
|
||||
ProtectionLevel apiv1.ProtectionLevel
|
||||
}
|
||||
|
||||
var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
baseClient := keyvault.New()
|
||||
|
||||
// With an URI, try to log in only using client credentials in the URI.
|
||||
// Client credentials requires:
|
||||
// - client-id
|
||||
// - client-secret
|
||||
// - tenant-id
|
||||
// And optionally the aad-endpoint to support custom clouds:
|
||||
// - aad-endpoint (defaults to https://login.microsoftonline.com/)
|
||||
if opts.URI != "" {
|
||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Required options
|
||||
clientID := u.Get("client-id")
|
||||
clientSecret := u.Get("client-secret")
|
||||
tenantID := u.Get("tenant-id")
|
||||
// optional
|
||||
aadEndpoint := u.Get("aad-endpoint")
|
||||
|
||||
if clientID != "" && clientSecret != "" && tenantID != "" {
|
||||
s := auth.EnvironmentSettings{
|
||||
Values: map[string]string{
|
||||
auth.ClientID: clientID,
|
||||
auth.ClientSecret: clientSecret,
|
||||
auth.TenantID: tenantID,
|
||||
auth.Resource: vaultResource,
|
||||
},
|
||||
Environment: azure.PublicCloud,
|
||||
}
|
||||
if aadEndpoint != "" {
|
||||
s.Environment.ActiveDirectoryEndpoint = aadEndpoint
|
||||
}
|
||||
baseClient.Authorizer, err = s.GetAuthorizer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return baseClient, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to authorize with the following methods:
|
||||
// 1. Environment variables.
|
||||
// - Client credentials
|
||||
// - Client certificate
|
||||
// - Username and password
|
||||
// - MSI
|
||||
// 2. Using Azure CLI 2.0 on local development.
|
||||
authorizer, err := auth.NewAuthorizerFromEnvironmentWithResource(vaultResource)
|
||||
if err != nil {
|
||||
authorizer, err = auth.NewAuthorizerFromCLIWithResource(vaultResource)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting authorizer for key vault")
|
||||
}
|
||||
}
|
||||
baseClient.Authorizer = authorizer
|
||||
return &baseClient, nil
|
||||
}
|
||||
|
||||
// New initializes a new KMS implemented using Azure Key Vault.
|
||||
func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) {
|
||||
baseClient, err := createClient(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// step and step-ca do not need and URI, but having a default vault and
|
||||
// protection level is useful if this package is used as an api
|
||||
var defaults DefaultOptions
|
||||
if opts.URI != "" {
|
||||
u, err := uri.ParseWithScheme(Scheme, opts.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defaults.Vault = u.Get("vault")
|
||||
if u.GetBool("hsm") {
|
||||
defaults.ProtectionLevel = apiv1.HSM
|
||||
}
|
||||
}
|
||||
|
||||
return &KeyVault{
|
||||
baseClient: baseClient,
|
||||
defaults: defaults,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPublicKey loads a public key from Azure Key Vault by its resource name.
|
||||
func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("getPublicKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
vault, name, version, _, err := parseKeyName(req.Name, k.defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := k.baseClient.GetKey(ctx, vaultBaseURL(vault), name, version)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault GetKey failed")
|
||||
}
|
||||
|
||||
return convertKey(resp.Key)
|
||||
}
|
||||
|
||||
// CreateKey creates a asymmetric key in Azure Key Vault.
|
||||
func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("createKeyRequest 'name' cannot be empty")
|
||||
}
|
||||
|
||||
vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Override protection level to HSM only if it's not specified, and is given
|
||||
// in the uri.
|
||||
protectionLevel := req.ProtectionLevel
|
||||
if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm {
|
||||
protectionLevel = apiv1.HSM
|
||||
}
|
||||
|
||||
kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm)
|
||||
}
|
||||
var keySize *int32
|
||||
if kt.Kty == keyvault.RSA || kt.Kty == keyvault.RSAHSM {
|
||||
switch req.Bits {
|
||||
case 2048:
|
||||
keySize = &value2048
|
||||
case 0, 3072:
|
||||
keySize = &value3072
|
||||
case 4096:
|
||||
keySize = &value4096
|
||||
default:
|
||||
return nil, errors.Errorf("keyVault does not support key size %d", req.Bits)
|
||||
}
|
||||
}
|
||||
|
||||
created := date.UnixTime(now())
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{
|
||||
Kty: kt.KeyType(protectionLevel),
|
||||
KeySize: keySize,
|
||||
Curve: kt.Curve,
|
||||
KeyOps: &[]keyvault.JSONWebKeyOperation{
|
||||
keyvault.Sign, keyvault.Verify,
|
||||
},
|
||||
KeyAttributes: &keyvault.KeyAttributes{
|
||||
Enabled: &valueTrue,
|
||||
Created: &created,
|
||||
NotBefore: &created,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault CreateKey failed")
|
||||
}
|
||||
|
||||
publicKey, err := convertKey(resp.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyURI := getKeyName(vault, name, resp)
|
||||
return &apiv1.CreateKeyResponse{
|
||||
Name: keyURI,
|
||||
PublicKey: publicKey,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: keyURI,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateSigner returns a crypto.Signer from a previously created asymmetric key.
|
||||
func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
||||
if req.SigningKey == "" {
|
||||
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
|
||||
}
|
||||
return NewSigner(k.baseClient, req.SigningKey, k.defaults)
|
||||
}
|
||||
|
||||
// Close closes the client connection to the Azure Key Vault. This is a noop.
|
||||
func (k *KeyVault) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateName validates that the given string is a valid URI.
|
||||
func (k *KeyVault) ValidateName(s string) error {
|
||||
_, _, _, _, err := parseKeyName(s, k.defaults)
|
||||
return err
|
||||
}
|
@ -0,0 +1,653 @@
|
||||
//go:generate mockgen -package mock -mock_names=KeyVaultClient=KeyVaultClient -destination internal/mock/key_vault_client.go github.com/smallstep/certificates/kms/azurekms KeyVaultClient
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/azurekms/internal/mock"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
var errTest = fmt.Errorf("test error")
|
||||
|
||||
func mockNow(t *testing.T) time.Time {
|
||||
old := now
|
||||
t0 := time.Unix(1234567890, 123).UTC()
|
||||
now = func() time.Time {
|
||||
return t0
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
now = old
|
||||
})
|
||||
return t0
|
||||
}
|
||||
|
||||
func mockClient(t *testing.T) *mock.KeyVaultClient {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(func() {
|
||||
ctrl.Finish()
|
||||
})
|
||||
return mock.NewKeyVaultClient(ctrl)
|
||||
}
|
||||
|
||||
func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(&jose.JSONWebKey{
|
||||
Key: pub,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key := new(keyvault.JSONWebKey)
|
||||
if err := json.Unmarshal(b, key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func Test_now(t *testing.T) {
|
||||
t0 := now()
|
||||
if loc := t0.Location(); loc != time.UTC {
|
||||
t.Errorf("now() Location = %v, want %v", loc, time.UTC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
client := mockClient(t)
|
||||
old := createClient
|
||||
t.Cleanup(func() {
|
||||
createClient = old
|
||||
})
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func()
|
||||
args args
|
||||
want *KeyVault
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{}}, &KeyVault{
|
||||
baseClient: client,
|
||||
}, false},
|
||||
{"ok with vault", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:vault=my-vault",
|
||||
}}, &KeyVault{
|
||||
baseClient: client,
|
||||
defaults: DefaultOptions{
|
||||
Vault: "my-vault",
|
||||
ProtectionLevel: apiv1.UnspecifiedProtectionLevel,
|
||||
},
|
||||
}, false},
|
||||
{"ok with vault + hsm", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:vault=my-vault;hsm=true",
|
||||
}}, &KeyVault{
|
||||
baseClient: client,
|
||||
defaults: DefaultOptions{
|
||||
Vault: "my-vault",
|
||||
ProtectionLevel: apiv1.HSM,
|
||||
},
|
||||
}, false},
|
||||
{"fail", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return nil, errTest
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{}}, nil, true},
|
||||
{"fail uri", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{
|
||||
URI: "kms:vault=my-vault;hsm=true",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setup()
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_createClient(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
skip bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{context.Background(), apiv1.Options{}}, true, false},
|
||||
{"ok with uri", args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id",
|
||||
}}, false, false},
|
||||
{"ok with uri+aad", args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F",
|
||||
}}, false, false},
|
||||
{"ok with uri no config", args{context.Background(), apiv1.Options{
|
||||
URI: "azurekms:",
|
||||
}}, true, false},
|
||||
{"fail uri", args{context.Background(), apiv1.Options{
|
||||
URI: "kms:client-id=id;client-secret=secret;tenant-id=id",
|
||||
}}, false, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.skip {
|
||||
t.SkipNow()
|
||||
}
|
||||
_, err := createClient(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_GetPublicKey(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
jwk := createJWK(t, pub)
|
||||
|
||||
client := mockClient(t)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
||||
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.GetPublicKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.PublicKey
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
}}, pub, false},
|
||||
{"ok with version", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key?version=my-version",
|
||||
}}, pub, false},
|
||||
{"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "",
|
||||
}}, nil, true},
|
||||
{"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail id", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=;name=?version=my-version",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
got, err := k.GetPublicKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyVault.GetPublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_CreateKey(t *testing.T) {
|
||||
ecKey, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rsaKey, err := keyutil.GenerateSigner("RSA", "", 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ecPub := ecKey.Public()
|
||||
rsaPub := rsaKey.Public()
|
||||
ecJWK := createJWK(t, ecPub)
|
||||
rsaJWK := createJWK(t, rsaPub)
|
||||
|
||||
t0 := date.UnixTime(mockNow(t))
|
||||
client := mockClient(t)
|
||||
|
||||
expects := []struct {
|
||||
Name string
|
||||
Kty keyvault.JSONWebKeyType
|
||||
KeySize *int32
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
Key *keyvault.JSONWebKey
|
||||
}{
|
||||
{"P-256", keyvault.EC, nil, keyvault.P256, ecJWK},
|
||||
{"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
|
||||
{"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK},
|
||||
{"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK},
|
||||
{"P-384", keyvault.EC, nil, keyvault.P384, ecJWK},
|
||||
{"P-521", keyvault.EC, nil, keyvault.P521, ecJWK},
|
||||
{"RSA 0", keyvault.RSA, &value3072, "", rsaJWK},
|
||||
{"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK},
|
||||
{"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK},
|
||||
{"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK},
|
||||
{"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK},
|
||||
{"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK},
|
||||
}
|
||||
|
||||
for _, e := range expects {
|
||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", keyvault.KeyCreateParameters{
|
||||
Kty: e.Kty,
|
||||
KeySize: e.KeySize,
|
||||
Curve: e.Curve,
|
||||
KeyOps: &[]keyvault.JSONWebKeyOperation{
|
||||
keyvault.Sign, keyvault.Verify,
|
||||
},
|
||||
KeyAttributes: &keyvault.KeyAttributes{
|
||||
Enabled: &valueTrue,
|
||||
Created: &t0,
|
||||
NotBefore: &t0,
|
||||
},
|
||||
}).Return(keyvault.KeyBundle{
|
||||
Key: e.Key,
|
||||
}, nil)
|
||||
}
|
||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{}, errTest)
|
||||
client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{
|
||||
Key: nil,
|
||||
}, nil)
|
||||
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *apiv1.CreateKeyResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok P-256", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
ProtectionLevel: apiv1.Software,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-256 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
ProtectionLevel: apiv1.HSM,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key?hsm=true",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-384", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA384,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok P-521", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA512,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: ecPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 0", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 0,
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
||||
ProtectionLevel: apiv1.Software,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 0 HSM", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 0,
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
||||
ProtectionLevel: apiv1.HSM,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key;hsm=true",
|
||||
Bits: 0,
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSAPSS,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 2048,
|
||||
SignatureAlgorithm: apiv1.SHA384WithRSA,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 3072", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 3072,
|
||||
SignatureAlgorithm: apiv1.SHA512WithRSA,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"ok RSA 4096", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=my-key",
|
||||
Bits: 4096,
|
||||
SignatureAlgorithm: apiv1.SHA512WithRSAPSS,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "azurekms:name=my-key;vault=my-vault",
|
||||
PublicKey: rsaPub,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:name=my-key;vault=my-vault",
|
||||
},
|
||||
}, false},
|
||||
{"fail createKey", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail convertKey", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail name", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "",
|
||||
}}, nil, true},
|
||||
{"fail vault", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail id", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail SignatureAlgorithm", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.PureEd25519,
|
||||
}}, nil, true},
|
||||
{"fail bit size", fields{client}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found",
|
||||
SignatureAlgorithm: apiv1.SHA384WithRSAPSS,
|
||||
Bits: 1024,
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
got, err := k.CreateKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyVault.CreateKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_CreateSigner(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
jwk := createJWK(t, pub)
|
||||
|
||||
client := mockClient(t)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
||||
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateSignerRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:vault=my-vault;name=my-key",
|
||||
}}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"ok with version", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:vault=my-vault;name=my-key;version=my-version",
|
||||
}}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "my-version",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"fail GetKey", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "azurekms:vault=my-vault;name=not-found;version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail SigningKey", fields{client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
got, err := k.CreateSigner(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyVault.CreateSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_Close(t *testing.T) {
|
||||
client := mockClient(t)
|
||||
type fields struct {
|
||||
baseClient KeyVaultClient
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{
|
||||
baseClient: tt.fields.baseClient,
|
||||
}
|
||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.Close() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keyType_KeyType(t *testing.T) {
|
||||
type fields struct {
|
||||
Kty keyvault.JSONWebKeyType
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
}
|
||||
type args struct {
|
||||
pl apiv1.ProtectionLevel
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want keyvault.JSONWebKeyType
|
||||
}{
|
||||
{"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC},
|
||||
{"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC},
|
||||
{"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM},
|
||||
{"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA},
|
||||
{"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA},
|
||||
{"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM},
|
||||
{"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := keyType{
|
||||
Kty: tt.fields.Kty,
|
||||
Curve: tt.fields.Curve,
|
||||
}
|
||||
if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyVault_ValidateName(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{"azurekms:name=my-key;vault=my-vault"}, false},
|
||||
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, false},
|
||||
{"fail scheme", args{"azure:name=my-key;vault=my-vault"}, true},
|
||||
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, true},
|
||||
{"fail no name", args{"azurekms:vault=my-vault"}, true},
|
||||
{"fail no vault", args{"azurekms:name=my-key"}, true},
|
||||
{"fail empty", args{""}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KeyVault{}
|
||||
if err := k.ValidateName(tt.args.s); (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyVault.ValidateName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,160 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"math/big"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
)
|
||||
|
||||
// Signer implements a crypto.Signer using the AWS KMS.
|
||||
type Signer struct {
|
||||
client KeyVaultClient
|
||||
vaultBaseURL string
|
||||
name string
|
||||
version string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
|
||||
// NewSigner creates a new signer using a key in the AWS KMS.
|
||||
func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) {
|
||||
vault, name, version, _, err := parseKeyName(signingKey, defaults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure that the key exists.
|
||||
signer := &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: vaultBaseURL(vault),
|
||||
name: name,
|
||||
version: version,
|
||||
}
|
||||
if err := signer.preloadKey(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *Signer) preloadKey() error {
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "keyVault GetKey failed")
|
||||
}
|
||||
|
||||
s.publicKey, err = convertKey(resp.Key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Public returns the public key of this signer or an error.
|
||||
func (s *Signer) Public() crypto.PublicKey {
|
||||
return s.publicKey
|
||||
}
|
||||
|
||||
// Sign signs digest with the private key stored in the AWS KMS.
|
||||
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
alg, err := getSigningAlgorithm(s.Public(), opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
||||
|
||||
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
||||
Algorithm: alg,
|
||||
Value: &b64,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault Sign failed")
|
||||
}
|
||||
|
||||
sig, err := base64.RawURLEncoding.DecodeString(*resp.Result)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error decoding keyVault Sign result")
|
||||
}
|
||||
|
||||
var octetSize int
|
||||
switch alg {
|
||||
case keyvault.ES256:
|
||||
octetSize = 32 // 256-bit, concat(R,S) = 64 bytes
|
||||
case keyvault.ES384:
|
||||
octetSize = 48 // 384-bit, concat(R,S) = 96 bytes
|
||||
case keyvault.ES512:
|
||||
octetSize = 66 // 528-bit, concat(R,S) = 132 bytes
|
||||
default:
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// Convert to asn1
|
||||
if len(sig) != octetSize*2 {
|
||||
return nil, errors.Errorf("keyVault Sign failed: unexpected signature length")
|
||||
}
|
||||
var b cryptobyte.Builder
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R
|
||||
b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S
|
||||
})
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
||||
switch key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
hashFunc := opts.HashFunc()
|
||||
pss, isPSS := opts.(*rsa.PSSOptions)
|
||||
// Random salt lengths are not supported
|
||||
if isPSS &&
|
||||
pss.SaltLength != rsa.PSSSaltLengthAuto &&
|
||||
pss.SaltLength != rsa.PSSSaltLengthEqualsHash &&
|
||||
pss.SaltLength != hashFunc.Size() {
|
||||
return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength)
|
||||
}
|
||||
|
||||
switch h := hashFunc; h {
|
||||
case crypto.SHA256:
|
||||
if isPSS {
|
||||
return keyvault.PS256, nil
|
||||
}
|
||||
return keyvault.RS256, nil
|
||||
case crypto.SHA384:
|
||||
if isPSS {
|
||||
return keyvault.PS384, nil
|
||||
}
|
||||
return keyvault.RS384, nil
|
||||
case crypto.SHA512:
|
||||
if isPSS {
|
||||
return keyvault.PS512, nil
|
||||
}
|
||||
return keyvault.RS512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
switch h := opts.HashFunc(); h {
|
||||
case crypto.SHA256:
|
||||
return keyvault.ES256, nil
|
||||
case crypto.SHA384:
|
||||
return keyvault.ES384, nil
|
||||
case crypto.SHA512:
|
||||
return keyvault.ES512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
default:
|
||||
return "", errors.Errorf("unsupported key type %T", key)
|
||||
}
|
||||
}
|
@ -0,0 +1,352 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
)
|
||||
|
||||
func TestNewSigner(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
jwk := createJWK(t, pub)
|
||||
|
||||
client := mockClient(t)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{
|
||||
Key: jwk,
|
||||
}, nil)
|
||||
client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest)
|
||||
|
||||
var noOptions DefaultOptions
|
||||
type args struct {
|
||||
client KeyVaultClient
|
||||
signingKey string
|
||||
defaults DefaultOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "my-version",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{
|
||||
client: client,
|
||||
vaultBaseURL: "https://my-vault.vault.azure.net/",
|
||||
name: "my-key",
|
||||
version: "my-version",
|
||||
publicKey: pub,
|
||||
}, false},
|
||||
{"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
|
||||
{"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true},
|
||||
{"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true},
|
||||
{"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Public(t *testing.T) {
|
||||
key, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub := key.Public()
|
||||
|
||||
type fields struct {
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want crypto.PublicKey
|
||||
}{
|
||||
{"ok", fields{pub}, pub},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Sign(t *testing.T) {
|
||||
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
||||
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := opts.HashFunc().New()
|
||||
h.Write([]byte("random-data"))
|
||||
sum := h.Sum(nil)
|
||||
|
||||
var sig, resultSig []byte
|
||||
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
||||
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
curveBits := priv.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
// nolint:gocritic
|
||||
resultSig = append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
var b cryptobyte.Builder
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
b.AddASN1BigInt(r)
|
||||
b.AddASN1BigInt(s)
|
||||
})
|
||||
sig, err = b.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
sig, err = key.Sign(rand.Reader, sum, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultSig = sig
|
||||
}
|
||||
|
||||
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
||||
}
|
||||
|
||||
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
||||
p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384)
|
||||
p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512)
|
||||
rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256)
|
||||
rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384)
|
||||
rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512)
|
||||
rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
})
|
||||
rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
})
|
||||
rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
})
|
||||
|
||||
ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := mockClient(t)
|
||||
expects := []struct {
|
||||
name string
|
||||
keyVersion string
|
||||
alg keyvault.JSONWebKeySignatureAlgorithm
|
||||
digest []byte
|
||||
result keyvault.KeyOperationResult
|
||||
err error
|
||||
}{
|
||||
{"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
||||
Result: &p256ResultSig,
|
||||
}, nil},
|
||||
{"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{
|
||||
Result: &p386ResultSig,
|
||||
}, nil},
|
||||
{"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{
|
||||
Result: &p521ResultSig,
|
||||
}, nil},
|
||||
{"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA256ResultSig,
|
||||
}, nil},
|
||||
{"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA384ResultSig,
|
||||
}, nil},
|
||||
{"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA512ResultSig,
|
||||
}, nil},
|
||||
{"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaPSSSHA256ResultSig,
|
||||
}, nil},
|
||||
{"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaPSSSHA384ResultSig,
|
||||
}, nil},
|
||||
{"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaPSSSHA512ResultSig,
|
||||
}, nil},
|
||||
// Errors
|
||||
{"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest},
|
||||
{"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
||||
Result: &rsaSHA256ResultSig,
|
||||
}, nil},
|
||||
{"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{
|
||||
Result: func() *string {
|
||||
v := "😎"
|
||||
return &v
|
||||
}(),
|
||||
}, nil},
|
||||
}
|
||||
for _, e := range expects {
|
||||
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
||||
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
||||
Algorithm: e.alg,
|
||||
Value: &value,
|
||||
}).Return(e.result, e.err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
client KeyVaultClient
|
||||
vaultBaseURL string
|
||||
name string
|
||||
version string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
type args struct {
|
||||
rand io.Reader
|
||||
digest []byte
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, p256Sig, false},
|
||||
{"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{
|
||||
rand.Reader, p384Digest, crypto.SHA384,
|
||||
}, p384Sig, false},
|
||||
{"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{
|
||||
rand.Reader, p521Digest, crypto.SHA512,
|
||||
}, p521Sig, false},
|
||||
{"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
||||
rand.Reader, rsaSHA256Digest, crypto.SHA256,
|
||||
}, rsaSHA256Sig, false},
|
||||
{"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{
|
||||
rand.Reader, rsaSHA384Digest, crypto.SHA384,
|
||||
}, rsaSHA384Sig, false},
|
||||
{"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{
|
||||
rand.Reader, rsaSHA512Digest, crypto.SHA512,
|
||||
}, rsaSHA512Sig, false},
|
||||
{"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
|
||||
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}, rsaPSSSHA256Sig, false},
|
||||
{"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{
|
||||
rand.Reader, rsaPSSSHA384Digest, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
}, rsaPSSSHA384Sig, false},
|
||||
{"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{
|
||||
rand.Reader, rsaPSSSHA512Digest, &rsa.PSSOptions{
|
||||
SaltLength: 64,
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
}, rsaPSSSHA512Sig, false},
|
||||
{"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
||||
rand.Reader, rsaSHA256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
{"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
{"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
{"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{
|
||||
rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{
|
||||
SaltLength: 64,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}, nil, true},
|
||||
{"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{
|
||||
rand.Reader, rsaSHA256Digest, crypto.SHA1,
|
||||
}, nil, true},
|
||||
{"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.MD5,
|
||||
}, nil, true},
|
||||
{"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{
|
||||
rand.Reader, []byte("message"), crypto.Hash(0),
|
||||
}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
client: tt.fields.client,
|
||||
vaultBaseURL: tt.fields.vaultBaseURL,
|
||||
name: tt.fields.name,
|
||||
version: tt.fields.version,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,98 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/uri"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// defaultContext returns the default context used in requests to azure.
|
||||
func defaultContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), 15*time.Second)
|
||||
}
|
||||
|
||||
// getKeyName returns the uri of the key vault key.
|
||||
func getKeyName(vault, name string, bundle keyvault.KeyBundle) string {
|
||||
if bundle.Key != nil && bundle.Key.Kid != nil {
|
||||
sm := keyIDRegexp.FindAllStringSubmatch(*bundle.Key.Kid, 1)
|
||||
if len(sm) == 1 && len(sm[0]) == 4 {
|
||||
m := sm[0]
|
||||
u := uri.New(Scheme, url.Values{
|
||||
"vault": []string{m[1]},
|
||||
"name": []string{m[2]},
|
||||
})
|
||||
u.RawQuery = url.Values{"version": []string{m[3]}}.Encode()
|
||||
return u.String()
|
||||
}
|
||||
}
|
||||
// Fallback to URI without id.
|
||||
return uri.New(Scheme, url.Values{
|
||||
"vault": []string{vault},
|
||||
"name": []string{name},
|
||||
}).String()
|
||||
}
|
||||
|
||||
// parseKeyName returns the key vault, name and version from URIs like:
|
||||
//
|
||||
// - azurekms:vault=key-vault;name=key-name
|
||||
// - azurekms:vault=key-vault;name=key-name?version=key-id
|
||||
// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true
|
||||
//
|
||||
// The key-id defines the version of the key, if it is not passed the latest
|
||||
// version will be used.
|
||||
//
|
||||
// HSM can also be passed to define the protection level if this is not given in
|
||||
// CreateQuery.
|
||||
func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) {
|
||||
var u *uri.URI
|
||||
|
||||
u, err = uri.ParseWithScheme(Scheme, rawURI)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if name = u.Get("name"); name == "" {
|
||||
err = errors.Errorf("key uri %s is not valid: name is missing", rawURI)
|
||||
return
|
||||
}
|
||||
if vault = u.Get("vault"); vault == "" {
|
||||
if defaults.Vault == "" {
|
||||
name = ""
|
||||
err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI)
|
||||
return
|
||||
}
|
||||
vault = defaults.Vault
|
||||
}
|
||||
if u.Get("hsm") == "" {
|
||||
hsm = (defaults.ProtectionLevel == apiv1.HSM)
|
||||
} else {
|
||||
hsm = u.GetBool("hsm")
|
||||
}
|
||||
|
||||
version = u.Get("version")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func vaultBaseURL(vault string) string {
|
||||
return "https://" + vault + ".vault.azure.net/"
|
||||
}
|
||||
|
||||
func convertKey(key *keyvault.JSONWebKey) (crypto.PublicKey, error) {
|
||||
b, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling key")
|
||||
}
|
||||
var jwk jose.JSONWebKey
|
||||
if err := jwk.UnmarshalJSON(b); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling key")
|
||||
}
|
||||
return jwk.Key, nil
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package azurekms
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
)
|
||||
|
||||
func Test_getKeyName(t *testing.T) {
|
||||
getBundle := func(kid string) keyvault.KeyBundle {
|
||||
return keyvault.KeyBundle{
|
||||
Key: &keyvault.JSONWebKey{
|
||||
Kid: &kid,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type args struct {
|
||||
vault string
|
||||
name string
|
||||
bundle keyvault.KeyBundle
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"ok", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault?version=my-version"},
|
||||
{"ok default", args{"my-vault", "my-key", getBundle("https://my-vault.foo.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok too short", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-version")}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok too long", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version/sign")}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok nil key", args{"my-vault", "my-key", keyvault.KeyBundle{}}, "azurekms:name=my-key;vault=my-vault"},
|
||||
{"ok nil kid", args{"my-vault", "my-key", keyvault.KeyBundle{Key: &keyvault.JSONWebKey{}}}, "azurekms:name=my-key;vault=my-vault"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getKeyName(tt.args.vault, tt.args.name, tt.args.bundle); got != tt.want {
|
||||
t.Errorf("getKeyName() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseKeyName(t *testing.T) {
|
||||
var noOptions DefaultOptions
|
||||
type args struct {
|
||||
rawURI string
|
||||
defaults DefaultOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantVault string
|
||||
wantName string
|
||||
wantVersion string
|
||||
wantHsm bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
|
||||
{"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false},
|
||||
{"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false},
|
||||
{"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false},
|
||||
{"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false},
|
||||
{"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false},
|
||||
{"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false},
|
||||
{"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true},
|
||||
{"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true},
|
||||
{"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true},
|
||||
{"fail empty", args{"", noOptions}, "", "", "", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotVault != tt.wantVault {
|
||||
t.Errorf("parseKeyName() gotVault = %v, want %v", gotVault, tt.wantVault)
|
||||
}
|
||||
if gotName != tt.wantName {
|
||||
t.Errorf("parseKeyName() gotName = %v, want %v", gotName, tt.wantName)
|
||||
}
|
||||
if gotVersion != tt.wantVersion {
|
||||
t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion)
|
||||
}
|
||||
if gotHsm != tt.wantHsm {
|
||||
t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue