diff --git a/ca/client_test.go b/ca/client_test.go index 6fe8a135..5fd11179 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -9,15 +9,14 @@ import ( "encoding/json" "encoding/pem" "errors" - "fmt" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "time" - sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" @@ -26,6 +25,7 @@ import ( "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -107,52 +107,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` ) -func mustKey() *ecdsa.PrivateKey { +func mustKey(t *testing.T) *ecdsa.PrivateKey { + t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - panic(err) - } + require.NoError(t, err) return priv } -func parseCertificate(data string) *x509.Certificate { +func parseCertificate(t *testing.T, data string) *x509.Certificate { + t.Helper() block, _ := pem.Decode([]byte(data)) if block == nil { - panic("failed to parse certificate PEM") + require.Fail(t, "failed to parse certificate PEM") + return nil } cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - panic("failed to parse certificate: " + err.Error()) - } + require.NoError(t, err, "failed to parse certificate") return cert } -func parseCertificateRequest(string) *x509.CertificateRequest { +func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest { + t.Helper() block, _ := pem.Decode([]byte(csrPEM)) if block == nil { - panic("failed to parse certificate request PEM") + require.Fail(t, "failed to parse certificate request PEM") + return nil } csr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - panic("failed to parse certificate request: " + err.Error()) - } + require.NoError(t, err, "failed to parse certificate request") return csr } func equalJSON(t *testing.T, a, b interface{}) bool { + t.Helper() if reflect.DeepEqual(a, b) { return true } + ab, err := json.Marshal(a) - if err != nil { - t.Error(err) - return false - } + require.NoError(t, err) + bb, err := json.Marshal(b) - if err != nil { - t.Error(err) - return false - } + require.NoError(t, err) + return bytes.Equal(ab, bb) } @@ -177,32 +174,23 @@ func TestClient_Version(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() - if (err != nil) != tt.wantErr { - t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Version() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Version() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -227,40 +215,30 @@ func TestClient_Health(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Health() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Health() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Root(t *testing.T) { ok := &api.RootResponse{ - RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, } tests := []struct { @@ -281,10 +259,7 @@ func TestClient_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { expected := "/root/" + tt.shasum @@ -295,37 +270,31 @@ func TestClient_Root(t *testing.T) { }) got, err := c.Root(tt.shasum) - if (err != nil) != tt.wantErr { - t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Root() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Root() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Sign(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.SignRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, OTT: "the-ott", NotBefore: api.NewTimeDuration(time.Now()), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), @@ -351,16 +320,13 @@ func TestClient_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.SignRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - sassert.Fatal(t, ok, "response expected to be error type") + require.True(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -376,23 +342,16 @@ func TestClient_Sign(t *testing.T) { }) got, err := c.Sign(tt.request) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Sign() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Sign() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -423,16 +382,13 @@ func TestClient_Revoke(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.RevokeRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - sassert.Fatal(t, ok, "response expected to be error type") + require.True(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -448,34 +404,27 @@ func TestClient_Revoke(t *testing.T) { }) got, err := c.Revoke(tt.request, nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Revoke() = %v, want nil", got) - } - sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Renew(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -498,49 +447,38 @@ func TestClient_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Renew() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Renew() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_RenewWithToken(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -563,10 +501,7 @@ func TestClient_RenewWithToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Header.Get("Authorization") != "Bearer token" { @@ -577,44 +512,36 @@ func TestClient_RenewWithToken(t *testing.T) { }) got, err := c.RenewWithToken("token") - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.RenewWithToken() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.RekeyRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, } tests := []struct { @@ -637,38 +564,27 @@ func TestClient_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Renew() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Renew() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -700,10 +616,7 @@ func TestClient_Provisioners(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.RequestURI != tt.expectedURI { @@ -713,22 +626,16 @@ func TestClient_Provisioners(t *testing.T) { }) got, err := c.Provisioners(tt.args...) - if (err != nil) != tt.wantErr { - t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg)) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Provisioners() = %v, want nil", got) - } - sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -756,10 +663,7 @@ func TestClient_ProvisionerKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { expected := "/provisioners/" + tt.kid + "/encrypted-key" @@ -770,27 +674,20 @@ func TestClient_ProvisionerKey(t *testing.T) { }) got, err := c.ProvisionerKey(tt.kid) - if (err != nil) != tt.wantErr { - t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.ProvisionerKey() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -798,7 +695,7 @@ func TestClient_ProvisionerKey(t *testing.T) { func TestClient_Roots(t *testing.T) { ok := &api.RootsResponse{ Certificates: []api.Certificate{ - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -820,37 +717,27 @@ func TestClient_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Roots() = %v, want nil", got) - } - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Roots() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -858,7 +745,7 @@ func TestClient_Roots(t *testing.T) { func TestClient_Federation(t *testing.T) { ok := &api.FederationResponse{ Certificates: []api.Certificate{ - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -879,46 +766,34 @@ func TestClient_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Federation() = %v, want nil", got) - } - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Federation() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_SSHRoots(t *testing.T) { - key, err := ssh.NewPublicKey(mustKey().Public()) - if err != nil { - t.Fatal(err) - } + key, err := ssh.NewPublicKey(mustKey(t).Public()) + require.NoError(t, err) ok := &api.SSHRootsResponse{ HostKeys: []api.SSHPublicKey{{PublicKey: key}}, @@ -942,37 +817,27 @@ func TestClient_SSHRoots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.SSHKeys() = %v, want nil", got) - } - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -1004,13 +869,14 @@ func Test_parseEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := parseEndpoint(tt.args.endpoint) - if (err != nil) != tt.wantErr { - t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseEndpoint() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -1043,24 +909,21 @@ func TestClient_RootFingerprint(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tr := tt.server.Client().Transport c, err := NewClient(tt.server.URL, WithTransport(tr)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -1069,12 +932,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() - client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) - sassert.FatalError(t, err) + caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) + require.NoError(t, err) - fp, err := client.RootFingerprint() - sassert.FatalError(t, err) - sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) + fp, err := caClient.RootFingerprint() + assert.NoError(t, err) + assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } func TestClient_SSHBastion(t *testing.T) { @@ -1104,39 +967,29 @@ func TestClient_SSHBastion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr) - return - } - - switch { - case err != nil: - if got != nil { - t.Errorf("Client.SSHBastion() = %v, want nil", got) - } - if tt.responseCode != 200 { - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) + if tt.wantErr { + if assert.Error(t, err) { + if tt.responseCode != 200 { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - } - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response) } + assert.Nil(t, got) + return } + + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -1155,13 +1008,10 @@ func TestClient_GetCaURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } - if got := c.GetCaURL(); got != tt.want { - t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) - } + require.NoError(t, err) + + got := c.GetCaURL() + assert.Equal(t, tt.want, got) }) } } @@ -1171,7 +1021,7 @@ func Test_enforceRequestID(t *testing.T) { set.Header.Set("X-Request-Id", "already-set") inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) - new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + newRequestID := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) tests := []struct { name string @@ -1190,7 +1040,7 @@ func Test_enforceRequestID(t *testing.T) { }, { name: "new", - r: new, + r: newRequestID, }, } for _, tt := range tests { diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 7dea3dc8..c29947ad 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -130,7 +130,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { //nolint:gosec // test tls config func TestAddRootCA(t *testing.T) { - cert := parseCertificate(rootPEM) + cert := parseCertificate(t, rootPEM) pool := x509.NewCertPool() pool.AddCert(cert) @@ -163,7 +163,7 @@ func TestAddRootCA(t *testing.T) { //nolint:gosec // test tls config func TestAddClientCA(t *testing.T) { - cert := parseCertificate(rootPEM) + cert := parseCertificate(t, rootPEM) pool := x509.NewCertPool() pool.AddCert(cert) @@ -214,7 +214,7 @@ func TestAddRootsToRootCAs(t *testing.T) { t.Fatal(err) } - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -269,7 +269,7 @@ func TestAddRootsToClientCAs(t *testing.T) { t.Fatal(err) } - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -329,8 +329,8 @@ func TestAddFederationToRootCAs(t *testing.T) { t.Fatal(err) } - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) @@ -394,8 +394,8 @@ func TestAddFederationToClientCAs(t *testing.T) { t.Fatal(err) } - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) @@ -454,7 +454,7 @@ func TestAddRootsToCAs(t *testing.T) { t.Fatal(err) } - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -514,8 +514,8 @@ func TestAddFederationToCAs(t *testing.T) { t.Fatal(err) } - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) diff --git a/ca/tls_test.go b/ca/tls_test.go index dbcc6023..a19685ce 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -401,13 +401,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { } func TestCertificate(t *testing.T) { - cert := parseCertificate(certPEM) + cert := parseCertificate(t, certPEM) ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: cert}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: cert}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { @@ -434,12 +434,12 @@ func TestCertificate(t *testing.T) { } func TestIntermediateCertificate(t *testing.T) { - intermediate := parseCertificate(rootPEM) + intermediate := parseCertificate(t, rootPEM) ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: intermediate}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(t, certPEM)}, {Certificate: intermediate}, }, } @@ -467,24 +467,24 @@ func TestIntermediateCertificate(t *testing.T) { } func TestRootCertificateCertificate(t *testing.T) { - root := parseCertificate(rootPEM) + root := parseCertificate(t, rootPEM) ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ {root, root}, }}, } noTLS := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct {