smallstep-certificates/ca/identity/client_test.go
2022-03-28 14:55:39 -07:00

262 lines
7.3 KiB
Go

package identity
import (
"crypto/tls"
"crypto/x509"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"sort"
"testing"
)
// returnInput wraps val in a zero-argument function that always returns it.
// The tests use it to stub the path-returning package variables IdentityFile
// and DefaultsFile with fixed fixture paths.
func returnInput(val string) func() string {
	constant := func() string { return val }
	return constant
}
// TestClient loads a client from the identity and defaults fixtures and
// checks two things:
//   - A TLS server that trusts testdata/certs/root_ca.crt for client
//     certificates sees a peer certificate and answers 200.
//   - A server using httptest's default certificate (not one issued by the
//     test root) is rejected by the client.
func TestClient(t *testing.T) {
	oldIdentityFile := IdentityFile
	oldDefaultsFile := DefaultsFile
	defer func() {
		IdentityFile = oldIdentityFile
		DefaultsFile = oldDefaultsFile
	}()

	IdentityFile = returnInput("testdata/config/identity.json")
	DefaultsFile = returnInput("testdata/config/defaults.json")

	client, err := LoadClient()
	if err != nil {
		t.Fatal(err)
	}

	// okServer answers 401 when the client presented no certificate. With
	// VerifyClientCertIfGiven, any certificate that is presented must verify
	// against ClientCAs or the handshake fails.
	okServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
			w.WriteHeader(http.StatusUnauthorized)
		} else {
			w.WriteHeader(http.StatusOK)
		}
	}))
	defer okServer.Close()

	crt, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/secrets/server_key")
	if err != nil {
		t.Fatal(err)
	}
	b, err := os.ReadFile("testdata/certs/root_ca.crt")
	if err != nil {
		t.Fatal(err)
	}
	pool := x509.NewCertPool()
	// Fail fast on a broken fixture instead of surfacing it later as an
	// obscure handshake error.
	if !pool.AppendCertsFromPEM(b) {
		t.Fatal("no certificates parsed from testdata/certs/root_ca.crt")
	}
	okServer.TLS = &tls.Config{
		Certificates: []tls.Certificate{crt},
		ClientCAs:    pool,
		ClientAuth:   tls.VerifyClientCertIfGiven,
	}
	okServer.StartTLS()

	// badServer serves httptest's built-in certificate rather than one issued
	// by the test root, so the client is expected to refuse the connection.
	badServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("ok")) // best effort; the client should never get here
	}))
	defer badServer.Close()

	if resp, err := client.Get(okServer.URL); err != nil {
		t.Errorf("client.Get() error = %v", err)
	} else {
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Errorf("client.Get() = %d, want %d", resp.StatusCode, http.StatusOK)
		}
	}

	if resp, err := client.Get(badServer.URL); err == nil {
		// Unexpected success: still release the connection before reporting.
		resp.Body.Close()
		t.Errorf("client.Get() error = %v, wantErr true", err)
	}
}
// TestClient_ResolveReference checks Client.ResolveReference against CaURL:
//   - A path-only reference yields CaURL's scheme and host with the
//     reference's path.
//   - An absolute reference path replaces any path already on CaURL.
//   - The reference's query string is kept.
func TestClient_ResolveReference(t *testing.T) {
	type fields struct {
		CaURL *url.URL
	}
	type args struct {
		ref *url.URL
	}
	// Each case has a distinct name so a failure identifies its scenario
	// instead of showing up as ok, ok#01 or ok#02.
	tests := []struct {
		name   string
		fields fields
		args   args
		want   *url.URL
	}{
		{"ok", fields{&url.URL{Scheme: "https", Host: "localhost"}}, args{&url.URL{Path: "/foo"}}, &url.URL{Scheme: "https", Host: "localhost", Path: "/foo"}},
		{"ok replaces path", fields{&url.URL{Scheme: "https", Host: "localhost", Path: "/bar"}}, args{&url.URL{Path: "/foo"}}, &url.URL{Scheme: "https", Host: "localhost", Path: "/foo"}},
		{"ok keeps query", fields{&url.URL{Scheme: "https", Host: "localhost"}}, args{&url.URL{Path: "/foo", RawQuery: "foo=bar"}}, &url.URL{Scheme: "https", Host: "localhost", Path: "/foo", RawQuery: "foo=bar"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Client{
				CaURL: tt.fields.CaURL,
			}
			if got := c.ResolveReference(tt.args.ref); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Client.ResolveReference() = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestLoadClient checks LoadClient against the fixtures under testdata.
//
// In the "ok" case the returned client must satisfy three conditions:
//   - Its CaURL is https://127.0.0.1.
//   - Its transport trusts the same root subjects as testdata/certs/root_ca.crt.
//   - Its GetClientCertificate callback yields the key pair loaded from
//     testdata/identity/identity.crt and testdata/identity/identity_key.
//
// Every other case points IdentityFile or DefaultsFile at a missing or
// broken fixture and expects an error with a nil client.
func TestLoadClient(t *testing.T) {
	oldIdentityFile := IdentityFile
	oldDefaultsFile := DefaultsFile
	defer func() {
		IdentityFile = oldIdentityFile
		DefaultsFile = oldDefaultsFile
	}()

	crt, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
	if err != nil {
		t.Fatal(err)
	}
	b, err := os.ReadFile("testdata/certs/root_ca.crt")
	if err != nil {
		t.Fatal(err)
	}
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(b) {
		t.Fatal("no certificates parsed from testdata/certs/root_ca.crt")
	}
	tr := http.DefaultTransport.(*http.Transport).Clone()
	tr.TLSClientConfig = &tls.Config{
		Certificates: []tls.Certificate{crt},
		RootCAs:      pool,
	}
	expected := &Client{
		CaURL: &url.URL{Scheme: "https", Host: "127.0.0.1"},
		Client: &http.Client{
			Transport: tr,
		},
	}

	// Subtest names are unique so a failure pinpoints the offending fixture.
	tests := []struct {
		name    string
		prepare func()
		want    *Client
		wantErr bool
	}{
		{"ok", func() {
			IdentityFile = returnInput("testdata/config/identity.json")
			DefaultsFile = returnInput("testdata/config/defaults.json")
		}, expected, false},
		{"fail missing identity", func() {
			IdentityFile = returnInput("testdata/config/missing.json")
			DefaultsFile = returnInput("testdata/config/defaults.json")
		}, nil, true},
		{"fail identity", func() {
			IdentityFile = returnInput("testdata/config/fail.json")
			DefaultsFile = returnInput("testdata/config/defaults.json")
		}, nil, true},
		{"fail missing defaults", func() {
			IdentityFile = returnInput("testdata/config/identity.json")
			DefaultsFile = returnInput("testdata/config/missing.json")
		}, nil, true},
		{"fail defaults", func() {
			IdentityFile = returnInput("testdata/config/identity.json")
			DefaultsFile = returnInput("testdata/config/fail.json")
		}, nil, true},
		{"fail ca", func() {
			IdentityFile = returnInput("testdata/config/identity.json")
			DefaultsFile = returnInput("testdata/config/badca.json")
		}, nil, true},
		{"fail root", func() {
			IdentityFile = returnInput("testdata/config/identity.json")
			DefaultsFile = returnInput("testdata/config/badroot.json")
		}, nil, true},
		{"fail type", func() {
			IdentityFile = returnInput("testdata/config/badIdentity.json")
			DefaultsFile = returnInput("testdata/config/defaults.json")
		}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.prepare()
			got, err := LoadClient()
			if (err != nil) != tt.wantErr {
				t.Errorf("LoadClient() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.want == nil {
				if !reflect.DeepEqual(got, tt.want) {
					t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
				}
				return
			}

			// Guard nil fields and the type assertion so a malformed client
			// fails this subtest instead of panicking the whole test binary.
			if got == nil || got.Client == nil {
				t.Fatalf("LoadClient() = %#v, want a client with an http.Client", got)
			}
			gotTransport, ok := got.Client.Transport.(*http.Transport)
			if !ok {
				t.Fatalf("LoadClient() transport = %T, want *http.Transport", got.Client.Transport)
			}
			wantTransport := tt.want.Client.Transport.(*http.Transport)

			switch {
			case gotTransport.TLSClientConfig == nil:
				t.Error("LoadClient() transport does not define TLSClientConfig")
			case gotTransport.TLSClientConfig.GetClientCertificate == nil:
				t.Error("LoadClient() transport does not define GetClientCertificate")
			case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
				t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
			default:
				crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
				if err != nil {
					t.Errorf("LoadClient() GetClientCertificate error = %v", err)
				} else if !reflect.DeepEqual(*crt, wantTransport.TLSClientConfig.Certificates[0]) {
					t.Errorf("LoadClient() GetClientCertificate crt = %#v, want %#v", *crt, wantTransport.TLSClientConfig.Certificates[0])
				}
			}
		})
	}
}
func Test_defaultsConfig_Validate(t *testing.T) {
type fields struct {
CaURL string
Root string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{"https://127.0.0.1", "root_ca.crt"}, false},
{"fail ca-url", fields{"", "root_ca.crt"}, true},
{"fail root", fields{"https://127.0.0.1", ""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &defaultsConfig{
CaURL: tt.fields.CaURL,
Root: tt.fields.Root,
}
if err := c.Validate(); (err != nil) != tt.wantErr {
t.Errorf("defaultsConfig.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// equalPools reports whether a and b contain the same set of certificate
// subjects.
//
// The comparison works in three steps:
//   - Deeply equal pools, including two nil pools, match immediately.
//   - A nil pool never matches a non-nil one. In a tls.Config a nil RootCAs
//     means "use the host's roots", which differs from any explicit pool, and
//     (*x509.CertPool).Subjects would dereference the nil pointer.
//   - Otherwise the raw DER subjects of both pools are sorted and compared,
//     so insertion order does not matter.
//
// CertPool.Subjects is deprecated, hence the lint suppression below.
// nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool {
	if reflect.DeepEqual(a, b) {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	sortedSubjects := func(p *x509.CertPool) []string {
		raw := p.Subjects()
		out := make([]string, len(raw))
		for i := range raw {
			out[i] = string(raw[i])
		}
		sort.Strings(out)
		return out
	}
	return reflect.DeepEqual(sortedSubjects(a), sortedSubjects(b))
}