|
|
|
@ -113,6 +113,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|
|
|
|
clientDomain := "test.domain"
|
|
|
|
|
// Create server with given tls.Config
|
|
|
|
|
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
if req.RequestURI != "/no-cert" {
|
|
|
|
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Error("http.Request.TLS does not have peer certificates")
|
|
|
|
@ -128,15 +129,18 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|
|
|
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
w.Write([]byte("ok"))
|
|
|
|
|
}))
|
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
path string
|
|
|
|
|
wantErr bool
|
|
|
|
|
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
|
|
|
|
}{
|
|
|
|
|
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
tr, err := client.Transport(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("Client.Transport() error = %v", err)
|
|
|
|
@ -146,7 +150,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|
|
|
|
Transport: tr,
|
|
|
|
|
}
|
|
|
|
|
}},
|
|
|
|
|
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
|
|
|
@ -161,6 +165,28 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|
|
|
|
Transport: tr,
|
|
|
|
|
}
|
|
|
|
|
}},
|
|
|
|
|
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
root, err := RootCertificate(sr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("RootCertificate() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
tlsConfig := getDefaultTLSConfig(sr)
|
|
|
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
|
|
|
tlsConfig.RootCAs.AddCert(root)
|
|
|
|
|
|
|
|
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("getDefaultTransport() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return &http.Client{
|
|
|
|
|
Transport: tr,
|
|
|
|
|
}
|
|
|
|
|
}},
|
|
|
|
|
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
return &http.Client{}
|
|
|
|
|
}},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
@ -168,9 +194,13 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|
|
|
|
client, sr, pk := sign(clientDomain)
|
|
|
|
|
cli := tt.getClient(t, client, sr, pk)
|
|
|
|
|
if cli != nil {
|
|
|
|
|
resp, err := cli.Get(srv.URL)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("http.Client.Get() error = %v", err)
|
|
|
|
|
resp, err := cli.Get(srv.URL + tt.path)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
|
|
@ -301,6 +331,230 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
|
|
|
|
|
client, sr, pk := sign("127.0.0.1")
|
|
|
|
|
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
clientDomain := "test.domain"
|
|
|
|
|
// Create server with given tls.Config
|
|
|
|
|
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
if req.RequestURI != "/no-cert" {
|
|
|
|
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Error("http.Request.TLS does not have peer certificates")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
w.Write([]byte("ok"))
|
|
|
|
|
}))
|
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
path string
|
|
|
|
|
wantErr bool
|
|
|
|
|
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
|
|
|
|
}{
|
|
|
|
|
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
tr, err := client.Transport(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("Client.Transport() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return &http.Client{
|
|
|
|
|
Transport: tr,
|
|
|
|
|
}
|
|
|
|
|
}},
|
|
|
|
|
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("getDefaultTransport() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return &http.Client{
|
|
|
|
|
Transport: tr,
|
|
|
|
|
}
|
|
|
|
|
}},
|
|
|
|
|
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
|
|
|
root, err := RootCertificate(sr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("RootCertificate() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
tlsConfig := getDefaultTLSConfig(sr)
|
|
|
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
|
|
|
tlsConfig.RootCAs.AddCert(root)
|
|
|
|
|
|
|
|
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("getDefaultTransport() error = %v", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return &http.Client{
|
|
|
|
|
Transport: tr,
|
|
|
|
|
}
|
|
|
|
|
}},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
client, sr, pk := sign(clientDomain)
|
|
|
|
|
cli := tt.getClient(t, client, sr, pk)
|
|
|
|
|
if cli != nil {
|
|
|
|
|
resp, err := cli.Get(srv.URL + tt.path)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(b, []byte("ok")) {
|
|
|
|
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
|
|
|
|
if testing.Short() {
|
|
|
|
|
t.Skip("skipping test in short mode.")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Start CA
|
|
|
|
|
ca := startCATestServer()
|
|
|
|
|
defer ca.Close()
|
|
|
|
|
|
|
|
|
|
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
|
|
|
|
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
clientDomain := "test.domain"
|
|
|
|
|
fingerprints := make(map[string]struct{})
|
|
|
|
|
|
|
|
|
|
// Create server with given tls.Config
|
|
|
|
|
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Error("http.Request.TLS does not have peer certificates")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
|
|
|
|
w.Write([]byte("fail"))
|
|
|
|
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
// Add serial number to check rotation
|
|
|
|
|
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
|
|
|
|
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
|
|
|
|
|
w.Write([]byte("ok"))
|
|
|
|
|
}))
|
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
|
|
// Clients: transport and tlsConfig
|
|
|
|
|
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
|
|
|
|
tr1, err := client.Transport(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Client.Transport() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
|
|
|
|
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
tr2, err := getDefaultTransport(tlsConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("getDefaultTransport() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Disable keep alives to force TLS handshake
|
|
|
|
|
tr1.DisableKeepAlives = true
|
|
|
|
|
tr2.DisableKeepAlives = true
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
client *http.Client
|
|
|
|
|
}{
|
|
|
|
|
{"with transport", &http.Client{Transport: tr1}},
|
|
|
|
|
{"with tlsConfig", &http.Client{Transport: tr2}},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
resp, err := tt.client.Get(srv.URL)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("http.Client.Get() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(b, []byte("ok")) {
|
|
|
|
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if l := len(fingerprints); l != 2 {
|
|
|
|
|
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Wait for renewal 40s == 1m-1m/3
|
|
|
|
|
log.Printf("Sleeping for %s ...\n", 40*time.Second)
|
|
|
|
|
time.Sleep(40 * time.Second)
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run("renewed "+tt.name, func(t *testing.T) {
|
|
|
|
|
resp, err := tt.client.Get(srv.URL)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("http.Client.Get() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(b, []byte("ok")) {
|
|
|
|
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if l := len(fingerprints); l != 4 {
|
|
|
|
|
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCertificate(t *testing.T) {
|
|
|
|
|
cert := parseCertificate(certPEM)
|
|
|
|
|
ok := &api.SignResponse{
|
|
|
|
|