From 3f71b8debd5f6e1b1fad228e92ae5f15ab1a3e95 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 12 Dec 2019 12:48:34 -0800 Subject: [PATCH] Add mTLS test for identity client. --- ca/identity/client_test.go | 63 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index 9ab14e94..8ff27bba 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -5,11 +5,74 @@ import ( "crypto/x509" "io/ioutil" "net/http" + "net/http/httptest" "net/url" "reflect" "testing" ) +func TestClient(t *testing.T) { + oldIdentityFile := IdentityFile + oldDefaultsFile := DefaultsFile + defer func() { + IdentityFile = oldIdentityFile + DefaultsFile = oldDefaultsFile + }() + + IdentityFile = "testdata/config/identity.json" + DefaultsFile = "testdata/config/defaults.json" + + client, err := LoadClient() + if err != nil { + t.Fatal(err) + } + + okServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer okServer.Close() + + crt, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/secrets/server_key") + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") + if err != nil { + t.Fatal(err) + } + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(b) + + okServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{crt}, + ClientCAs: pool, + ClientAuth: tls.VerifyClientCertIfGiven, + } + okServer.StartTLS() + + badServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + defer badServer.Close() + + if resp, err := client.Get(okServer.URL); err != nil { + t.Errorf("client.Get() error = %v", err) + } else { + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("client.Get() = %d, want %d", resp.StatusCode, http.StatusOK) + } + } + + if _, err := client.Get(badServer.URL); err == nil { + t.Errorf("client.Get() error = %v, wantErr true", err) + } +} + func TestClient_ResolveReference(t *testing.T) { type fields struct { CaURL *url.URL