smallstep-certificates/ca/tls_test.go
Mariano Cano d872f09910 Use mTLS by default on SDK methods.
Add options to modify the tls.Config for different configurations.
Fixes #7
2018-11-21 13:31:09 -08:00

504 lines
14 KiB
Go

package ca
import (
"bytes"
"context"
"crypto"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/cli/crypto/randutil"
stepJOSE "github.com/smallstep/cli/jose"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
// generateOTT builds a one-time token for subject and returns it in compact
// JWT form. The token is signed with ES256 using the password-protected JWK
// stored in testdata/secrets/ott_mariano_priv.jwk, carries a random 64
// character ID, is issued by "mariano", and is valid for one minute starting
// at the time of the call. Any failure panics, as this is a test helper.
func generateOTT(subject string) string {
	issuedAt := time.Now()

	key, err := stepJOSE.ParseKey(
		"testdata/secrets/ott_mariano_priv.jwk",
		stepJOSE.WithPassword([]byte("password")),
	)
	if err != nil {
		panic(err)
	}

	// The key ID is exposed in the "kid" header of the token.
	signerOpts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", key.KeyID)
	signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key.Key}, signerOpts)
	if err != nil {
		panic(err)
	}

	tokenID, err := randutil.ASCII(64)
	if err != nil {
		panic(err)
	}

	claims := jwt.Claims{
		ID:        tokenID,
		Subject:   subject,
		Issuer:    "mariano",
		NotBefore: jwt.NewNumericDate(issuedAt),
		Expiry:    jwt.NewNumericDate(issuedAt.Add(time.Minute)),
		Audience:  []string{"https://127.0.0.1:0/sign"},
	}

	token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
	if err != nil {
		panic(err)
	}
	return token
}
// startTestServer starts an httptest server over TLS that serves handler with
// the given tlsConfig, and returns the running server. The caller is
// responsible for closing it.
func startTestServer(tlsConfig *tls.Config, handler http.Handler) *httptest.Server {
	server := httptest.NewUnstartedServer(handler)
	server.TLS = tlsConfig
	server.StartTLS()

	// Clear the static certificate list once the server is running to force
	// the use of GetCertificate on IPs.
	server.TLS.Certificates = nil

	return server
}
// startCATestServer loads the CA configuration from testdata/ca.json, builds
// a CA from it and serves the CA's handler, with the CA's TLS configuration,
// from an httptest server. It panics if the configuration cannot be loaded or
// the CA cannot be created.
func startCATestServer() *httptest.Server {
	cfg, err := authority.LoadConfiguration("testdata/ca.json")
	if err != nil {
		panic(err)
	}

	testCA, err := New(cfg)
	if err != nil {
		panic(err)
	}

	// The CA's own server is never started; an httptest.Server takes its
	// place, reusing the CA's TLS configuration and handler.
	tlsCfg, handler := testCA.srv.TLSConfig, testCA.srv.Handler
	return startTestServer(tlsCfg, handler)
}
// sign starts a temporary test CA, requests a certificate for domain with no
// explicit duration, and returns the CA client, the sign response and the
// private key of the new certificate. The CA server is closed before sign
// returns.
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
	caServer := startCATestServer()
	defer caServer.Close()

	client, resp, key := signDuration(caServer, domain, 0)
	return client, resp, key
}
// signDuration requests a certificate for domain from the CA running at srv
// and returns the CA client, the sign response and the generated private key.
//
// When duration is greater than zero the request asks for a validity period
// that starts now and lasts for duration; otherwise the validity fields of
// the request are left untouched. Any failure panics.
func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
	signReq, key, err := CreateSignRequest(generateOTT(domain))
	if err != nil {
		panic(err)
	}

	if duration > 0 {
		signReq.NotBefore = time.Now()
		signReq.NotAfter = signReq.NotBefore.Add(duration)
	}

	// The client trusts the test root stored in testdata.
	caClient, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	if err != nil {
		panic(err)
	}

	signResp, err := caClient.Sign(signReq)
	if err != nil {
		panic(err)
	}

	return caClient, signResp, key
}
// serverHandler returns the handler used by the test servers.
//
// Requests to "/no-cert" are answered with "ok" without looking at the TLS
// state. Every other request must carry a peer certificate whose common name
// is clientDomain and whose DNS names are exactly [clientDomain]; otherwise
// the handler writes "fail" and reports the mismatch on t. On success the
// hex-encoded SHA-256 of the raw peer certificate is returned in the
// "x-fingerprint" header, followed by the body "ok".
func serverHandler(t *testing.T, clientDomain string) http.Handler {
	handle := func(w http.ResponseWriter, req *http.Request) {
		if req.RequestURI == "/no-cert" {
			w.Write([]byte("ok"))
			return
		}

		if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
			w.Write([]byte("fail"))
			t.Error("http.Request.TLS does not have peer certificates")
			return
		}

		peer := req.TLS.PeerCertificates[0]
		if cn := peer.Subject.CommonName; cn != clientDomain {
			w.Write([]byte("fail"))
			t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", cn, clientDomain)
			return
		}
		if wantNames := []string{clientDomain}; !reflect.DeepEqual(peer.DNSNames, wantNames) {
			w.Write([]byte("fail"))
			t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", peer.DNSNames, wantNames)
			return
		}

		// Expose the certificate fingerprint so callers can detect rotation.
		digest := sha256.Sum256(peer.Raw)
		w.Header().Set("x-fingerprint", hex.EncodeToString(digest[:]))
		w.Write([]byte("ok"))
	}
	return http.HandlerFunc(handle)
}
// TestClient_GetServerTLSConfig_http starts two servers configured with
// Client.GetServerTLSConfig, one with the default options (the mTLS server)
// and one with VerifyClientCertIfGiven (the TLS server), and checks which
// kinds of HTTP clients are able to talk to each of them:
//
//   - clients built with Client.Transport or Client.GetClientTLSConfig reach
//     both servers;
//   - a client that trusts the root but presents no certificate only reaches
//     the TLS server;
//   - a default http.Client fails against both.
//
// Each entry of wantErr maps the URL to request to whether the request is
// expected to fail.
func TestClient_GetServerTLSConfig_http(t *testing.T) {
	clientDomain := "test.domain"
	client, sr, pk := sign("127.0.0.1")

	// Create mTLS server
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
	defer srvMTLS.Close()

	// Create TLS server
	ctx, cancel = context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
	defer srvTLS.Close()

	tests := []struct {
		name      string
		getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
		wantErr   map[string]bool
	}{
		{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			tr, err := client.Transport(context.Background(), sr, pk)
			if err != nil {
				t.Errorf("Client.Transport() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: tr,
			}
		}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
		{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
			if err != nil {
				t.Errorf("Client.GetClientTLSConfig() error = %v", err)
				return nil
			}
			tr, err := getDefaultTransport(tlsConfig)
			if err != nil {
				t.Errorf("getDefaultTransport() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: tr,
			}
		}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
		{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			root, err := RootCertificate(sr)
			if err != nil {
				t.Errorf("RootCertificate() error = %v", err)
				return nil
			}
			tlsConfig := getDefaultTLSConfig(sr)
			tlsConfig.RootCAs = x509.NewCertPool()
			tlsConfig.RootCAs.AddCert(root)
			tr, err := getDefaultTransport(tlsConfig)
			if err != nil {
				t.Errorf("getDefaultTransport() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: tr,
			}
		}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
		{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			return &http.Client{}
		}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Every case uses a freshly signed client certificate.
			client, sr, pk := sign(clientDomain)
			cli := tt.getClient(t, client, sr, pk)
			if cli == nil {
				// getClient has already reported the failure.
				return
			}
			for path, wantErr := range tt.wantErr {
				t.Run(path, func(t *testing.T) {
					resp, err := cli.Get(path)
					if err == nil {
						// Close the body on every path, including the one
						// where the request succeeded unexpectedly.
						defer resp.Body.Close()
					}
					if (err != nil) != wantErr {
						t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
						return
					}
					if wantErr {
						return
					}
					b, err := ioutil.ReadAll(resp.Body)
					if err != nil {
						t.Fatalf("ioutil.ReadAll() error = %v", err)
					}
					if !bytes.Equal(b, []byte("ok")) {
						t.Errorf("response body unexpected, got %s, want ok", b)
					}
				})
			}
		})
	}
}
// TestClient_GetServerTLSConfig_renew checks that the certificates used by
// servers and clients created through the SDK are renewed.
//
// It signs one-minute certificates, issues a first round of requests and
// records the client certificate fingerprints reported by the servers in the
// "x-fingerprint" header (two are expected: one for each client that presents
// a certificate). It then sleeps for 40 seconds, repeats the requests, and
// expects two additional fingerprints.
//
// The test is skipped in short mode because of the sleep.
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	// Start CA
	ca := startCATestServer()
	defer ca.Close()

	clientDomain := "test.domain"
	client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)

	// Start mTLS server
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
	defer srvMTLS.Close()

	// Start TLS server
	ctx, cancel = context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
	defer srvTLS.Close()

	// Transport
	client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
	tr1, err := client.Transport(context.Background(), sr, pk)
	if err != nil {
		t.Fatalf("Client.Transport() error = %v", err)
	}

	// Transport with tlsConfig
	client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
	tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
	if err != nil {
		t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
	}
	tr2, err := getDefaultTransport(tlsConfig)
	if err != nil {
		t.Fatalf("getDefaultTransport() error = %v", err)
	}

	// No client cert
	root, err := RootCertificate(sr)
	if err != nil {
		t.Fatalf("RootCertificate() error = %v", err)
	}
	tlsConfig = getDefaultTLSConfig(sr)
	tlsConfig.RootCAs = x509.NewCertPool()
	tlsConfig.RootCAs.AddCert(root)
	tr3, err := getDefaultTransport(tlsConfig)
	if err != nil {
		t.Fatalf("getDefaultTransport() error = %v", err)
	}

	// Disable keep alives to force TLS handshake
	tr1.DisableKeepAlives = true
	tr2.DisableKeepAlives = true
	tr3.DisableKeepAlives = true

	tests := []struct {
		name    string
		client  *http.Client
		wantErr map[string]bool
	}{
		{"with transport", &http.Client{Transport: tr1}, map[string]bool{
			srvTLS.URL:  false,
			srvMTLS.URL: false,
		}},
		{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
			srvTLS.URL:  false,
			srvMTLS.URL: false,
		}},
		{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
			srvTLS.URL + "/no-cert":  false,
			srvMTLS.URL + "/no-cert": true,
		}},
		{"fail with default", &http.Client{}, map[string]bool{
			srvTLS.URL + "/no-cert":  true,
			srvMTLS.URL + "/no-cert": true,
		}},
	}

	// To count different cert fingerprints
	fingerprints := map[string]struct{}{}

	// runRequests issues every request of every test case as a subtest named
	// prefix+case name, and records any fingerprint returned by the servers.
	runRequests := func(prefix string) {
		for _, tt := range tests {
			t.Run(prefix+tt.name, func(t *testing.T) {
				for path, wantErr := range tt.wantErr {
					t.Run(path, func(t *testing.T) {
						resp, err := tt.client.Get(path)
						if err == nil {
							// Close the body on every path, including the
							// one where the request succeeded unexpectedly.
							defer resp.Body.Close()
						}
						if (err != nil) != wantErr {
							t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
							return
						}
						if wantErr {
							return
						}
						if fp := resp.Header.Get("x-fingerprint"); fp != "" {
							fingerprints[fp] = struct{}{}
						}
						b, err := ioutil.ReadAll(resp.Body)
						if err != nil {
							t.Errorf("ioutil.ReadAll() error = %v", err)
							return
						}
						if !bytes.Equal(b, []byte("ok")) {
							t.Errorf("response body unexpected, got %s, want ok", b)
							return
						}
					})
				}
			})
		}
	}

	// First round: only the two clients that present a certificate produce a
	// fingerprint.
	runRequests("")
	if l := len(fingerprints); l != 2 {
		t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
	}

	// Wait for renewal 40s == 1m-1m/3
	log.Printf("Sleeping for %s ...\n", 40*time.Second)
	time.Sleep(40 * time.Second)

	// Second round: the renewed client certificates add two new fingerprints.
	runRequests("renewed ")
	if l := len(fingerprints); l != 4 {
		t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
	}
}
// TestCertificate checks that Certificate returns the certificate stored in
// the ServerPEM field of a sign response, and that it fails on an empty
// response.
func TestCertificate(t *testing.T) {
	leaf := parseCertificate(certPEM)
	valid := &api.SignResponse{
		ServerPEM: api.Certificate{Certificate: leaf},
		CaPEM:     api.Certificate{Certificate: parseCertificate(rootPEM)},
	}

	type testCase struct {
		name    string
		sign    *api.SignResponse
		want    *x509.Certificate
		wantErr bool
	}
	cases := []testCase{
		{name: "ok", sign: valid, want: leaf},
		{name: "fail", sign: &api.SignResponse{}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Certificate(tc.sign)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Certificate() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("Certificate() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestIntermediateCertificate checks that IntermediateCertificate returns the
// certificate stored in the CaPEM field of a sign response, and that it fails
// on an empty response.
func TestIntermediateCertificate(t *testing.T) {
	issuer := parseCertificate(rootPEM)
	valid := &api.SignResponse{
		ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
		CaPEM:     api.Certificate{Certificate: issuer},
	}

	type testCase struct {
		name    string
		sign    *api.SignResponse
		want    *x509.Certificate
		wantErr bool
	}
	cases := []testCase{
		{name: "ok", sign: valid, want: issuer},
		{name: "fail", sign: &api.SignResponse{}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := IntermediateCertificate(tc.sign)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("IntermediateCertificate() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("IntermediateCertificate() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestRootCertificateCertificate checks that RootCertificate returns the root
// taken from the verified chains of the TLS connection state attached to a
// sign response, and that it fails when the response is empty or carries no
// TLS state.
func TestRootCertificateCertificate(t *testing.T) {
	root := parseCertificate(rootPEM)

	withTLS := &api.SignResponse{
		ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
		CaPEM:     api.Certificate{Certificate: parseCertificate(rootPEM)},
		TLS: &tls.ConnectionState{
			VerifiedChains: [][]*x509.Certificate{{root, root}},
		},
	}
	withoutTLS := &api.SignResponse{
		ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
		CaPEM:     api.Certificate{Certificate: parseCertificate(rootPEM)},
	}

	type testCase struct {
		name    string
		sign    *api.SignResponse
		want    *x509.Certificate
		wantErr bool
	}
	cases := []testCase{
		{name: "ok", sign: withTLS, want: root},
		{name: "fail", sign: &api.SignResponse{}, wantErr: true},
		{name: "no tls", sign: withoutTLS, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := RootCertificate(tc.sign)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("RootCertificate() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("RootCertificate() = %v, want %v", got, tc.want)
			}
		})
	}
}