smallstep-certificates/kms/sshagentkms/sshagentkms_test.go
2021-11-13 01:30:03 +01:00

610 lines
18 KiB
Go

package sshagentkms
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"net"
"os"
"os/exec"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"github.com/smallstep/certificates/kms/apiv1"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"go.step.sm/crypto/pemutil"
)
// Some helpers with inspiration from crypto/ssh/agent/client_test.go
// startOpenSSHAgent executes ssh-agent and returns an agent client connected
// to it, the path of the agent's unix socket, and a cleanup function.
//
// The test is skipped when no ssh-agent binary can be found in PATH, and is
// failed when the agent cannot be started, its output cannot be parsed, or
// its socket cannot be dialed. The returned cleanup kills the agent process,
// closes the client connection and removes the directory holding the socket.
func startOpenSSHAgent(t *testing.T) (client agent.Agent, socket string, cleanup func()) {
	t.Helper()

	/* Always test with OpenSSHAgent
	if testing.Short() {
		// ssh-agent is not always available, and the key
		// types supported vary by platform.
		t.Skip("skipping test due to -short")
	}
	*/
	bin, err := exec.LookPath("ssh-agent")
	if err != nil {
		t.Skip("could not find ssh-agent")
	}

	cmd := exec.Command(bin, "-s")
	cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior.
	cmd.Stderr = new(bytes.Buffer)
	out, err := cmd.Output()
	if err != nil {
		t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
	}

	// Output looks like:
	//
	//   SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
	//   SSH_AGENT_PID=15542; export SSH_AGENT_PID;
	//   echo Agent pid 15542;
	fields := bytes.Split(out, []byte(";"))
	if len(fields) < 3 {
		// Guard the fields[2] access below instead of panicking on
		// unexpected output.
		t.Fatalf("unexpected ssh-agent output %q", out)
	}

	// valueOf extracts VALUE from a "KEY=VALUE" field (ignoring leading
	// newlines in front of KEY) and fails the test if the key does not match.
	valueOf := func(field []byte, key string) string {
		kv := bytes.SplitN(field, []byte("="), 2)
		if len(kv) != 2 || string(bytes.TrimLeft(kv[0], "\n")) != key {
			t.Fatalf("could not find key %s in %q", key, field)
		}
		return string(kv[1])
	}

	socket = valueOf(fields[0], "SSH_AUTH_SOCK")
	// The socket path is known from here on, so make sure its directory is
	// removed on the failure paths below as well.
	removeSocketDir := func() {
		os.RemoveAll(filepath.Dir(socket))
	}

	pidStr := valueOf(fields[2], "SSH_AGENT_PID")
	pid, err := strconv.Atoi(pidStr)
	if err != nil {
		removeSocketDir()
		t.Fatalf("Atoi(%q): %v", pidStr, err)
	}

	// killAgent terminates the ssh-agent process on a best-effort basis.
	killAgent := func() {
		if proc, _ := os.FindProcess(pid); proc != nil {
			proc.Kill()
		}
	}

	conn, err := net.Dial("unix", socket)
	if err != nil {
		// The agent is already running at this point; do not leak it.
		killAgent()
		removeSocketDir()
		t.Fatalf("net.Dial: %v", err)
	}

	return agent.NewClient(conn), socket, func() {
		killAgent()
		conn.Close()
		removeSocketDir()
	}
}
// startAgent serves sshagent on one end of a local connection pair and
// returns an agent client attached to the other end. The returned cleanup
// closes both ends of the pair. The test is failed if the pair cannot be
// created.
func startAgent(t *testing.T, sshagent agent.Agent) (client agent.Agent, cleanup func()) {
	clientConn, serverConn, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}

	// Serve the agent protocol in the background on the server side.
	go agent.ServeAgent(sshagent, serverConn)

	client = agent.NewClient(clientConn)
	cleanup = func() {
		clientConn.Close()
		serverConn.Close()
	}
	return client, cleanup
}
// startKeyringAgent backs a simulated ssh-agent server with an in-memory
// keyring and returns a client for it together with its cleanup function.
func startKeyringAgent(t *testing.T) (client agent.Agent, cleanup func()) {
	keyring := agent.NewKeyring()
	client, cleanup = startAgent(t, keyring)
	return client, cleanup
}
// startTestAgentFunc is the signature shared by the helpers that start a test
// agent and add keysToAdd to it (startTestOpenSSHAgent and
// startTestKeyringAgent), so tests can be run against each kind of agent.
type startTestAgentFunc func(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent)
// startTestOpenSSHAgent starts a real ssh-agent process, adds keysToAdd to it
// and returns a client for it. The agent is torn down through t.Cleanup. The
// test is failed if any key cannot be added.
func startTestOpenSSHAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
	sshagent, _, cleanup := startOpenSSHAgent(t)
	// Register the cleanup before anything below can call t.Fatalf, otherwise
	// a failing Add would leave the ssh-agent process running.
	t.Cleanup(cleanup)

	for _, keyToAdd := range keysToAdd {
		if err := sshagent.Add(keyToAdd); err != nil {
			t.Fatalf("sshagent.add: %v", err)
		}
	}

	//testAgentInterface(t, sshagent, key, cert, lifetimeSecs)
	return sshagent
}
// startTestKeyringAgent starts a keyring-backed agent, adds keysToAdd to it
// and returns a client for it. The agent connections are closed through
// t.Cleanup. The test is failed if any key cannot be added.
func startTestKeyringAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
	sshagent, cleanup := startKeyringAgent(t)
	// Register the cleanup before anything below can call t.Fatalf, otherwise
	// a failing Add would leave the connections open.
	t.Cleanup(cleanup)

	for _, keyToAdd := range keysToAdd {
		if err := sshagent.Add(keyToAdd); err != nil {
			t.Fatalf("sshagent.add: %v", err)
		}
	}

	//testAgentInterface(t, agent, key, cert, lifetimeSecs)
	return sshagent
}
// netPipe returns the two ends of a connection established over a temporary
// localhost listener. It plays the role of net.Pipe, but because it uses a
// real net.Conn it is buffered (net.Pipe deadlocks if both sides start with
// a write). The listener is closed before returning.
func netPipe() (net.Conn, net.Conn, error) {
	ln, err := netListener()
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	addr := ln.Addr().String()
	dialed, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, nil, err
	}

	accepted, err := ln.Accept()
	if err != nil {
		// Do not leak the dialing side when the accept fails.
		dialed.Close()
		return nil, nil, err
	}

	return dialed, accepted, nil
}
// netListener opens a TCP listener on an ephemeral localhost port. The IPv4
// loopback address is tried first and the IPv6 loopback address second; if
// both fail, the error of the last attempt is returned.
func netListener() (net.Listener, error) {
	var (
		ln  net.Listener
		err error
	)
	for _, addr := range []string{"127.0.0.1:0", "[::1]:0"} {
		if ln, err = net.Listen("tcp", addr); err == nil {
			return ln, nil
		}
	}
	return nil, err
}
// TestNew checks New against a real ssh-agent reached through SSH_AUTH_SOCK:
// a signer for the test key must not be available before the key is added to
// the agent, and must be available afterwards.
func TestNew(t *testing.T) {
	comment := "Key from OpenSSHAgent"

	// Remember the caller's SSH_AUTH_SOCK so that it can be restored once the
	// test is done instead of being left unset for whatever runs next.
	origSocket, hadSocket := os.LookupEnv("SSH_AUTH_SOCK")
	t.Cleanup(func() {
		if hadSocket {
			os.Setenv("SSH_AUTH_SOCK", origSocket)
			return
		}
		os.Unsetenv("SSH_AUTH_SOCK")
	})

	// Ensure we don't "inherit" any SSH_AUTH_SOCK
	os.Unsetenv("SSH_AUTH_SOCK")

	sshagent, socket, cleanup := startOpenSSHAgent(t)
	// Registered right away so the agent is stopped on every exit path.
	t.Cleanup(cleanup)
	if err := os.Setenv("SSH_AUTH_SOCK", socket); err != nil {
		t.Fatal(err)
	}

	// Test that we can't find any signers in the agent before we have loaded them
	t.Run("No keys with OpenSSHAgent", func(t *testing.T) {
		kms, err := New(context.Background(), apiv1.Options{})
		if kms == nil || err != nil {
			// Fatal: kms is dereferenced below.
			t.Fatalf("New() = %v, %v", kms, err)
		}
		signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
		if err == nil || signer != nil {
			t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
		}
	})

	// Load ssh test fixtures
	b, err := os.ReadFile("testdata/ssh")
	if err != nil {
		t.Fatal(err)
	}
	privateKey, err := ssh.ParseRawPrivateKey(b)
	if err != nil {
		t.Fatal(err)
	}

	// And add that key to the agent
	err = sshagent.Add(agent.AddedKey{PrivateKey: privateKey, Comment: comment})
	if err != nil {
		t.Fatalf("sshagent.add: %v", err)
	}

	// And test that we can find it when it's loaded
	t.Run("Keys with OpenSSHAgent", func(t *testing.T) {
		kms, err := New(context.Background(), apiv1.Options{})
		if kms == nil || err != nil {
			// Fatal: kms is dereferenced below.
			t.Fatalf("New() = %v, %v", kms, err)
		}
		signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
		if err != nil || signer == nil {
			t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
		}
	})
}
// TestNewFromAgent checks that NewFromAgent returns a non-nil KMS and no
// error for both the OpenSSH-backed and the keyring-backed test agents.
func TestNewFromAgent(t *testing.T) {
	cases := []struct {
		name    string
		ctx     context.Context
		opts    apiv1.Options
		start   startTestAgentFunc
		wantErr bool
	}{
		{name: "ok OpenSSHAgent", ctx: context.Background(), opts: apiv1.Options{}, start: startTestOpenSSHAgent, wantErr: false},
		{name: "ok KeyringAgent", ctx: context.Background(), opts: apiv1.Options{}, start: startTestKeyringAgent, wantErr: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			kms, err := NewFromAgent(tc.ctx, tc.opts, tc.start(t))
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("NewFromAgent() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if kms == nil {
				t.Errorf("NewFromAgent() = %v", kms)
			}
		})
	}
}
// TestSSHAgentKMS_Close checks that Close on a zero-value SSHAgentKMS reports
// the expected error state.
func TestSSHAgentKMS_Close(t *testing.T) {
	cases := []struct {
		name    string
		wantErr bool
	}{
		{name: "ok", wantErr: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			kms := new(SSHAgentKMS)
			err := kms.Close()
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("SSHAgentKMS.Close() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}
// TestSSHAgentKMS_CreateSigner runs CreateSigner against both test agents,
// each preloaded with the key from testdata/ssh. It covers signers given
// directly, as PEM, as files, and by "sshagentkms:<comment>" name, plus the
// failure cases for each of those sources.
func TestSSHAgentKMS_CreateSigner(t *testing.T) {
	pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	pemBlock, err := pemutil.Serialize(pk)
	if err != nil {
		t.Fatal(err)
	}
	pemBlockPassword, err := pemutil.Serialize(pk, pemutil.WithPassword([]byte("pass")))
	if err != nil {
		t.Fatal(err)
	}

	// Read and decode file using standard packages
	b, err := os.ReadFile("testdata/priv.pem")
	if err != nil {
		t.Fatal(err)
	}
	block, _ := pem.Decode(b)
	if block == nil {
		// pem.Decode returns a nil block when no PEM data is found; fail
		// cleanly instead of dereferencing it below.
		t.Fatal("testdata/priv.pem does not contain a PEM block")
	}
	block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
	if err != nil {
		t.Fatal(err)
	}
	pk2, err := x509.ParseECPrivateKey(block.Bytes)
	if err != nil {
		t.Fatal(err)
	}

	// Create a public PEM
	b, err = x509.MarshalPKIXPublicKey(pk.Public())
	if err != nil {
		t.Fatal(err)
	}
	pub := pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: b,
	})

	// Load ssh test fixtures
	sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
	if err != nil {
		t.Fatal(err)
	}
	_, comment, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyStr)
	if err != nil {
		t.Fatal(err)
	}
	b, err = os.ReadFile("testdata/ssh")
	if err != nil {
		t.Fatal(err)
	}
	privateKey, err := ssh.ParseRawPrivateKey(b)
	if err != nil {
		t.Fatal(err)
	}
	sshPrivateKey, err := ssh.NewSignerFromKey(privateKey)
	if err != nil {
		t.Fatal(err)
	}
	wrappedSSHPrivateKey := NewWrappedSignerFromSSHSigner(sshPrivateKey)

	type args struct {
		req *apiv1.CreateSignerRequest
	}
	tests := []struct {
		name    string
		args    args
		want    crypto.Signer
		wantErr bool
	}{
		{"signer", args{&apiv1.CreateSignerRequest{Signer: pk}}, pk, false},
		{"pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlock)}}, pk, false},
		{"pem password", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, pk, false},
		{"file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("pass")}}, pk2, false},
		{"sshagent", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment}}, wrappedSSHPrivateKey, false},
		{"sshagent Nonexistant", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:Nonexistant"}}, nil, true},
		{"fail", args{&apiv1.CreateSignerRequest{}}, nil, true},
		{"fail bad pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: []byte("bad pem")}}, nil, true},
		{"fail bad password", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("bad-pass")}}, nil, true},
		{"fail not a signer", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pub}}, nil, true},
		{"fail not a signer from file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/pub.pem"}}, nil, true},
		{"fail missing", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/missing"}}, nil, true},
	}

	starters := []struct {
		name    string
		starter startTestAgentFunc
	}{
		{"startTestOpenSSHAgent", startTestOpenSSHAgent},
		{"startTestKeyringAgent", startTestKeyringAgent},
	}

	for _, starter := range starters {
		k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: privateKey, Comment: comment}))
		if err != nil {
			t.Fatal(err)
		}
		for _, tt := range tests {
			t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
				got, err := k.CreateSigner(tt.args.req)
				if (err != nil) != tt.wantErr {
					t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
					return
				}
				// nolint:gocritic
				switch s := got.(type) {
				case *WrappedSSHSigner:
					// Agent-backed signers are compared through the
					// authorized_keys representation of their public key,
					// which must match the contents of testdata/ssh.pub.
					agentKey, ok := s.Sshsigner.PublicKey().(*agent.Key)
					if !ok {
						t.Fatalf("SSHAgentKMS.CreateSigner() public key = %T, want *agent.Key", s.Sshsigner.PublicKey())
					}
					gotPkS := agentKey.String() + "\n"
					wantPkS := string(sshPubKeyStr)
					if gotPkS != wantPkS {
						t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", gotPkS, wantPkS)
					}
				default:
					if !reflect.DeepEqual(got, tt.want) {
						t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", got, tt.want)
						t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", got, tt.want)
					}
				}
			})
		}
	}
}
/*
func restoreGenerateKey() func() {
oldGenerateKey := generateKey
return func() {
generateKey = oldGenerateKey
}
}
*/
/*
func TestSSHAgentKMS_CreateKey(t *testing.T) {
fn := restoreGenerateKey()
defer fn()
p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
edpub, edpriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
type args struct {
req *apiv1.CreateKeyRequest
}
type params struct {
kty string
crv string
size int
}
tests := []struct {
name string
args args
generateKey func() (interface{}, interface{}, error)
want *apiv1.CreateKeyResponse
wantParams params
wantErr bool
}{
{"p256", args{&apiv1.CreateKeyRequest{Name: "p256", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, &apiv1.CreateKeyResponse{Name: "p256", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
{"rsa", args{&apiv1.CreateKeyRequest{Name: "rsa3072", SignatureAlgorithm: apiv1.SHA256WithRSA}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa3072", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 0}, false},
{"rsa2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
{"rsaPSS2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, Bits: 2048}}, func() (interface{}, interface{}, error) {
return rsa2048.Public(), rsa2048, nil
}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
{"ed25519", args{&apiv1.CreateKeyRequest{Name: "ed25519", SignatureAlgorithm: apiv1.PureEd25519}}, func() (interface{}, interface{}, error) {
return edpub, edpriv, nil
}, &apiv1.CreateKeyResponse{Name: "ed25519", PublicKey: edpub, PrivateKey: edpriv, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: edpriv}}, params{"OKP", "Ed25519", 0}, false},
{"default", args{&apiv1.CreateKeyRequest{Name: "default"}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, &apiv1.CreateKeyResponse{Name: "default", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
{"fail algorithm", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, func() (interface{}, interface{}, error) {
return p256.Public(), p256, nil
}, nil, params{}, true},
{"fail generate key", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return nil, nil, fmt.Errorf("an error")
}, nil, params{"EC", "P-256", 0}, true},
{"fail no signer", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
return 1, 2, nil
}, nil, params{"EC", "P-256", 0}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &SSHAgentKMS{}
generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
if tt.wantParams.kty != kty {
t.Errorf("GenerateKey() kty = %s, want %s", kty, tt.wantParams.kty)
}
if tt.wantParams.crv != crv {
t.Errorf("GenerateKey() crv = %s, want %s", crv, tt.wantParams.crv)
}
if tt.wantParams.size != size {
t.Errorf("GenerateKey() size = %d, want %d", size, tt.wantParams.size)
}
return tt.generateKey()
}
got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SSHAgentKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSHAgentKMS.CreateKey() = %v, want %v", got, tt.want)
}
})
}
}
*/
// TestSSHAgentKMS_GetPublicKey runs GetPublicKey against both test agents,
// each preloaded with the key from testdata/ssh. It covers public keys read
// from PEM and certificate files and keys looked up in the agent by
// "sshagentkms:<comment>" name, plus the failure cases.
func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
	b, err := os.ReadFile("testdata/pub.pem")
	if err != nil {
		t.Fatal(err)
	}
	block, _ := pem.Decode(b)
	if block == nil {
		// pem.Decode returns a nil block when no PEM data is found; fail
		// cleanly instead of dereferencing it below.
		t.Fatal("testdata/pub.pem does not contain a PEM block")
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		t.Fatal(err)
	}

	// Load ssh test fixtures
	b, err = os.ReadFile("testdata/ssh.pub")
	if err != nil {
		t.Fatal(err)
	}
	sshPubKey, comment, _, _, err := ssh.ParseAuthorizedKey(b)
	if err != nil {
		t.Fatal(err)
	}
	b, err = os.ReadFile("testdata/ssh")
	if err != nil {
		t.Fatal(err)
	}
	// crypto.PrivateKey
	sshPrivateKey, err := ssh.ParseRawPrivateKey(b)
	if err != nil {
		t.Fatal(err)
	}

	type args struct {
		req *apiv1.GetPublicKeyRequest
	}
	tests := []struct {
		name    string
		args    args
		want    crypto.PublicKey
		wantErr bool
	}{
		{"key", args{&apiv1.GetPublicKeyRequest{Name: "testdata/pub.pem"}}, pub, false},
		{"cert", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.crt"}}, pub, false},
		{"sshagent", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:" + comment}}, sshPubKey, false},
		{"sshagent Nonexistant", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:Nonexistant"}}, nil, true},
		{"fail not exists", args{&apiv1.GetPublicKeyRequest{Name: "testdata/missing"}}, nil, true},
		{"fail type", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.key"}}, nil, true},
	}

	starters := []struct {
		name    string
		starter startTestAgentFunc
	}{
		{"startTestOpenSSHAgent", startTestOpenSSHAgent},
		{"startTestKeyringAgent", startTestKeyringAgent},
	}

	for _, starter := range starters {
		k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: sshPrivateKey, Comment: comment}))
		if err != nil {
			t.Fatal(err)
		}
		for _, tt := range tests {
			t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
				got, err := k.GetPublicKey(tt.args.req)
				if (err != nil) != tt.wantErr {
					t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
					return
				}
				// If we want an ssh.PublicKey, convert got to an
				// ssh.PublicKey as well so both sides are compared in the
				// same representation.
				if _, wantSSH := tt.want.(ssh.PublicKey); wantSSH {
					got, err = ssh.NewPublicKey(got)
					if err != nil {
						t.Fatal(err)
					}
				}
				if !reflect.DeepEqual(got, tt.want) {
					t.Errorf("SSHAgentKMS.GetPublicKey() = %T, want %T", got, tt.want)
					t.Errorf("SSHAgentKMS.GetPublicKey() = %v, want %v", got, tt.want)
				}
			})
		}
	}
}
// TestSSHAgentKMS_CreateKey checks, for both test agents, that CreateKey
// returns a nil response together with a non-nil error.
func TestSSHAgentKMS_CreateKey(t *testing.T) {
	starters := []struct {
		name    string
		starter startTestAgentFunc
	}{
		{"startTestOpenSSHAgent", startTestOpenSSHAgent},
		{"startTestKeyringAgent", startTestKeyringAgent},
	}

	for _, starter := range starters {
		k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t))
		if err != nil {
			t.Fatal(err)
		}
		t.Run(starter.name+"/CreateKey", func(t *testing.T) {
			got, err := k.CreateKey(&apiv1.CreateKeyRequest{
				Name:               "sshagentkms:0",
				SignatureAlgorithm: apiv1.ECDSAWithSHA256,
			})
			if got != nil {
				t.Error("SSHAgentKMS.CreateKey() shouldn't return a value")
			}
			if err == nil {
				// This branch is about the missing error, not a missing value.
				t.Error("SSHAgentKMS.CreateKey() didn't return an error")
			}
		})
	}
}