smallstep-certificates/ca/tls_test.go
Mariano Cano d739aab345
Define BaseContext before starting the server in tests
If the http.Server BaseContext is not defined before the start of the
server, it might not be properly set depending on the goroutine
scheduler. This was causing random errors on CI.
2023-08-17 12:56:26 -07:00

513 lines
14 KiB
Go

package ca
import (
"bytes"
"context"
"crypto"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
)
// generateOTT builds and signs a JWT for the given subject and returns it in
// compact serialization. The signing key is read from
// testdata/secrets/ott_mariano_priv.jwk, the token is valid for one minute
// starting now, and the subject is also used as the only entry of the "sans"
// claim. Any error causes a panic.
func generateOTT(subject string) string {
	// ottClaims is the token payload: the registered claims plus "sans".
	type ottClaims struct {
		jose.Claims
		SANS []string `json:"sans"`
	}

	issuedAt := time.Now()

	key, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
	if err != nil {
		panic(err)
	}

	// The key ID is advertised in the "kid" header of the token.
	signerOpts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", key.KeyID)
	signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key.Key}, signerOpts)
	if err != nil {
		panic(err)
	}

	// Random token identifier.
	tokenID, err := randutil.ASCII(64)
	if err != nil {
		panic(err)
	}

	payload := ottClaims{
		Claims: jose.Claims{
			ID:        tokenID,
			Subject:   subject,
			Issuer:    "mariano",
			NotBefore: jose.NewNumericDate(issuedAt),
			Expiry:    jose.NewNumericDate(issuedAt.Add(time.Minute)),
			Audience:  []string{"https://127.0.0.1:0/sign"},
		},
		SANS: []string{subject},
	}

	token, err := jose.Signed(signer).Claims(payload).CompactSerialize()
	if err != nil {
		panic(err)
	}
	return token
}
// startTestServer starts an httptest TLS server that serves handler using
// tlsConfig, and whose requests derive their context from baseContext.
//
// After the server is started its certificate list is cleared, so the
// handshake must obtain the certificate through tlsConfig.GetCertificate.
func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler http.Handler) *httptest.Server {
	ts := httptest.NewUnstartedServer(handler)
	ts.TLS = tlsConfig

	// The base context MUST be configured before the server starts.
	ts.Config.BaseContext = func(net.Listener) context.Context {
		return baseContext
	}

	ts.StartTLS()

	// Drop the certificates so GetCertificate is used, also for IPs.
	ts.TLS.Certificates = nil

	return ts
}
// startCATestServer creates a CA from testdata/ca.json and serves its handler
// and TLS configuration from an httptest.Server. It panics if the
// configuration cannot be loaded or the CA cannot be created.
func startCATestServer() *httptest.Server {
	cfg, err := authority.LoadConfiguration("testdata/ca.json")
	if err != nil {
		panic(err)
	}

	testCA, err := New(cfg)
	if err != nil {
		panic(err)
	}

	// Serve through an httptest.Server instead of the CA's own server.
	return startTestServer(
		buildContext(testCA.auth, nil, nil, nil),
		testCA.srv.TLSConfig,
		testCA.srv.Handler,
	)
}
// sign starts a temporary CA test server, requests a certificate for domain
// with the default validity (duration 0) and closes the server before
// returning the CA client, the sign response and the generated private key.
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
	caServer := startCATestServer()
	defer caServer.Close()

	client, resp, key := signDuration(caServer, domain, 0)
	return client, resp, key
}
// signDuration requests a certificate for domain from the CA running on srv.
// When duration is positive the request asks for a validity that starts now
// and lasts for duration; otherwise the validity fields of the request are
// left untouched. It returns the CA client, the sign response and the private
// key created for the request, and panics on any error.
func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
	signReq, priv, err := CreateSignRequest(generateOTT(domain))
	if err != nil {
		panic(err)
	}

	if duration > 0 {
		signReq.NotBefore = api.NewTimeDuration(time.Now())
		notAfter := signReq.NotBefore.Time().Add(duration)
		signReq.NotAfter = api.NewTimeDuration(notAfter)
	}

	caClient, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	if err != nil {
		panic(err)
	}

	signResp, err := caClient.Sign(signReq)
	if err != nil {
		panic(err)
	}

	return caClient, signResp, priv
}
// serverHandler returns the handler used by the test servers.
//
// Requests to "/no-cert" are answered with "ok" without looking at the TLS
// state. Any other request must present a peer certificate whose common name
// is clientDomain and whose DNS names are exactly [clientDomain]; otherwise
// the handler writes "fail" and reports an error on t. On success the SHA-256
// fingerprint of the peer certificate is returned hex-encoded in the
// "x-fingerprint" header, so callers can detect certificate rotation, and the
// body is "ok".
func serverHandler(t *testing.T, clientDomain string) http.Handler {
	wantNames := []string{clientDomain}

	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// No client certificate is required on this path.
		if req.RequestURI == "/no-cert" {
			w.Write([]byte("ok"))
			return
		}

		state := req.TLS
		switch {
		case state == nil || len(state.PeerCertificates) == 0:
			w.Write([]byte("fail"))
			t.Error("http.Request.TLS does not have peer certificates")
		case state.PeerCertificates[0].Subject.CommonName != clientDomain:
			w.Write([]byte("fail"))
			t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", state.PeerCertificates[0].Subject.CommonName, clientDomain)
		case !reflect.DeepEqual(state.PeerCertificates[0].DNSNames, wantNames):
			w.Write([]byte("fail"))
			t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", state.PeerCertificates[0].DNSNames, wantNames)
		default:
			// Expose the certificate fingerprint to check rotation.
			digest := sha256.Sum256(state.PeerCertificates[0].Raw)
			w.Header().Set("x-fingerprint", hex.EncodeToString(digest[:]))
			w.Write([]byte("ok"))
		}
	})
}
// TestClient_GetServerTLSConfig_http starts two servers configured with
// Client.GetServerTLSConfig, one with the default options and one with
// VerifyClientCertIfGiven, and checks which kinds of HTTP client can reach
// them:
//
//   - clients built with Client.Transport or Client.GetClientTLSConfig must
//     reach both servers;
//   - a client that trusts the root but presents no certificate must only
//     reach the VerifyClientCertIfGiven server;
//   - a default http.Client must fail against both.
func TestClient_GetServerTLSConfig_http(t *testing.T) {
	clientDomain := "test.domain"
	client, sr, pk := sign("127.0.0.1")

	// Create mTLS server
	ctxMTLS, cancelMTLS := context.WithCancel(context.Background())
	defer cancelMTLS()
	tlsConfig, err := client.GetServerTLSConfig(ctxMTLS, sr, pk)
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvMTLS.Close()

	// Create TLS server
	ctxTLS, cancelTLS := context.WithCancel(context.Background())
	defer cancelTLS()
	tlsConfig, err = client.GetServerTLSConfig(ctxTLS, sr, pk, VerifyClientCertIfGiven())
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvTLS.Close()

	// Each case builds an HTTP client from a freshly signed certificate and
	// maps the URLs to request to whether the request is expected to fail.
	tests := []struct {
		name      string
		getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
		wantErr   map[string]bool
	}{
		{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			tr, err := client.Transport(context.Background(), sr, pk)
			if err != nil {
				t.Errorf("Client.Transport() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: tr,
			}
		}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
		{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
			if err != nil {
				t.Errorf("Client.GetClientTLSConfig() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: getDefaultTransport(tlsConfig),
			}
		}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
		{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			root, err := RootCertificate(sr)
			if err != nil {
				t.Errorf("RootCertificate() error = %v", err)
				return nil
			}
			tlsConfig := getDefaultTLSConfig(sr)
			tlsConfig.RootCAs = x509.NewCertPool()
			tlsConfig.RootCAs.AddCert(root)
			return &http.Client{
				Transport: getDefaultTransport(tlsConfig),
			}
		}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
		{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			return &http.Client{}
		}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, sr, pk := sign(clientDomain)
			cli := tt.getClient(t, client, sr, pk)
			if cli == nil {
				// getClient already reported the error.
				return
			}
			for path, wantErr := range tt.wantErr {
				t.Run(path, func(t *testing.T) {
					resp, err := cli.Get(path)
					if err == nil {
						// Close the body on every path, including the one
						// where the request unexpectedly succeeds.
						defer resp.Body.Close()
					}
					if (err != nil) != wantErr {
						t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
						return
					}
					if wantErr {
						return
					}
					b, err := io.ReadAll(resp.Body)
					if err != nil {
						t.Fatalf("io.ReadAll() error = %v", err)
					}
					if !bytes.Equal(b, []byte("ok")) {
						t.Errorf("response body unexpected, got %s, want ok", b)
					}
				})
			}
		})
	}
}
// TestClient_GetServerTLSConfig_renew checks that server and client
// certificates keep working after they are renewed.
//
// All certificates are signed for five seconds. The same set of requests is
// sent twice, before and after sleeping for five seconds, with keep-alives
// disabled so every request performs a new TLS handshake. The servers report
// the fingerprint of the client certificate in the "x-fingerprint" header;
// the two certificate-bearing clients must produce two distinct fingerprints
// in the first round and four in total after the second round.
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
	// setMinCertDuration returns the function that undoes the change.
	reset := setMinCertDuration(1 * time.Second)
	defer reset()

	// Start CA
	ca := startCATestServer()
	defer ca.Close()

	clientDomain := "test.domain"
	client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second)

	// Start mTLS server
	ctxMTLS, cancelMTLS := context.WithCancel(context.Background())
	defer cancelMTLS()
	tlsConfig, err := client.GetServerTLSConfig(ctxMTLS, sr, pk)
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvMTLS.Close()

	// Start TLS server
	ctxTLS, cancelTLS := context.WithCancel(context.Background())
	defer cancelTLS()
	tlsConfig, err = client.GetServerTLSConfig(ctxTLS, sr, pk, VerifyClientCertIfGiven())
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvTLS.Close()

	// Transport
	client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
	tr1, err := client.Transport(context.Background(), sr, pk)
	if err != nil {
		t.Fatalf("Client.Transport() error = %v", err)
	}

	// Transport with tlsConfig
	client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
	tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
	if err != nil {
		t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
	}
	tr2 := getDefaultTransport(tlsConfig)

	// No client cert
	root, err := RootCertificate(sr)
	if err != nil {
		t.Fatalf("RootCertificate() error = %v", err)
	}
	tlsConfig = getDefaultTLSConfig(sr)
	tlsConfig.RootCAs = x509.NewCertPool()
	tlsConfig.RootCAs.AddCert(root)
	tr3 := getDefaultTransport(tlsConfig)

	// Disable keep alives to force TLS handshake
	tr1.DisableKeepAlives = true
	tr2.DisableKeepAlives = true
	tr3.DisableKeepAlives = true

	// Each case maps the URLs to request to whether the request must fail.
	tests := []struct {
		name    string
		client  *http.Client
		wantErr map[string]bool
	}{
		{"with transport", &http.Client{Transport: tr1}, map[string]bool{
			srvTLS.URL:  false,
			srvMTLS.URL: false,
		}},
		{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
			srvTLS.URL:  false,
			srvMTLS.URL: false,
		}},
		{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
			srvTLS.URL + "/no-cert":  false,
			srvMTLS.URL + "/no-cert": true,
		}},
		{"fail with default", &http.Client{}, map[string]bool{
			srvTLS.URL + "/no-cert":  true,
			srvMTLS.URL + "/no-cert": true,
		}},
	}

	// To count different cert fingerprints
	fingerprints := map[string]struct{}{}

	// runRequests performs every request of every case as a subtest named
	// prefix+case name, and records the fingerprints returned by the servers.
	runRequests := func(prefix string) {
		for _, tt := range tests {
			t.Run(prefix+tt.name, func(t *testing.T) {
				for path, wantErr := range tt.wantErr {
					t.Run(path, func(t *testing.T) {
						resp, err := tt.client.Get(path)
						if err == nil {
							// Close the body on every path, including the
							// one where the request unexpectedly succeeds.
							defer resp.Body.Close()
						}
						if (err != nil) != wantErr {
							t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
							return
						}
						if wantErr {
							return
						}
						if fp := resp.Header.Get("x-fingerprint"); fp != "" {
							fingerprints[fp] = struct{}{}
						}
						b, err := io.ReadAll(resp.Body)
						if err != nil {
							t.Errorf("io.ReadAll() error = %v", err)
							return
						}
						if !bytes.Equal(b, []byte("ok")) {
							t.Errorf("response body unexpected, got %s, want ok", b)
							return
						}
					})
				}
			})
		}
	}

	runRequests("")
	if l := len(fingerprints); l != 2 {
		t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
	}

	// Wait for renewal
	log.Printf("Sleeping for %s ...\n", 5*time.Second)
	time.Sleep(5 * time.Second)

	runRequests("renewed ")
	if l := len(fingerprints); l != 4 {
		t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
	}
}
func TestCertificate(t *testing.T) {
cert := parseCertificate(certPEM)
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: cert},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: cert},
{Certificate: parseCertificate(rootPEM)},
},
}
tests := []struct {
name string
sign *api.SignResponse
want *x509.Certificate
wantErr bool
}{
{"ok", ok, cert, false},
{"fail", &api.SignResponse{}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Certificate(tt.sign)
if (err != nil) != tt.wantErr {
t.Errorf("Certificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Certificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestIntermediateCertificate(t *testing.T) {
intermediate := parseCertificate(rootPEM)
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: intermediate},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: intermediate},
},
}
tests := []struct {
name string
sign *api.SignResponse
want *x509.Certificate
wantErr bool
}{
{"ok", ok, intermediate, false},
{"fail", &api.SignResponse{}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := IntermediateCertificate(tt.sign)
if (err != nil) != tt.wantErr {
t.Errorf("IntermediateCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("IntermediateCertificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestRootCertificateCertificate(t *testing.T) {
root := parseCertificate(rootPEM)
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
{root, root},
}},
}
noTLS := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
}
tests := []struct {
name string
sign *api.SignResponse
want *x509.Certificate
wantErr bool
}{
{"ok", ok, root, false},
{"fail", &api.SignResponse{}, nil, true},
{"no tls", noTLS, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := RootCertificate(tt.sign)
if (err != nil) != tt.wantErr {
t.Errorf("RootCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RootCertificate() = %v, want %v", got, tt.want)
}
})
}
}