Cleanup CA client tests by removing `smallstep/assert`

pull/1743/head
Herman Slatman 3 months ago
parent 532b9df0a3
commit b9d6bfc1eb
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -9,15 +9,14 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
sassert "github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
@ -26,6 +25,7 @@ import (
"github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -107,52 +107,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
-----END CERTIFICATE REQUEST-----` -----END CERTIFICATE REQUEST-----`
) )
func mustKey() *ecdsa.PrivateKey { func mustKey(t *testing.T) *ecdsa.PrivateKey {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { require.NoError(t, err)
panic(err)
}
return priv return priv
} }
func parseCertificate(data string) *x509.Certificate { func parseCertificate(t *testing.T, data string) *x509.Certificate {
t.Helper()
block, _ := pem.Decode([]byte(data)) block, _ := pem.Decode([]byte(data))
if block == nil { if block == nil {
panic("failed to parse certificate PEM") require.Fail(t, "failed to parse certificate PEM")
return nil
} }
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { require.NoError(t, err, "failed to parse certificate")
panic("failed to parse certificate: " + err.Error())
}
return cert return cert
} }
func parseCertificateRequest(string) *x509.CertificateRequest { func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest {
t.Helper()
block, _ := pem.Decode([]byte(csrPEM)) block, _ := pem.Decode([]byte(csrPEM))
if block == nil { if block == nil {
panic("failed to parse certificate request PEM") require.Fail(t, "failed to parse certificate request PEM")
return nil
} }
csr, err := x509.ParseCertificateRequest(block.Bytes) csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil { require.NoError(t, err, "failed to parse certificate request")
panic("failed to parse certificate request: " + err.Error())
}
return csr return csr
} }
func equalJSON(t *testing.T, a, b interface{}) bool { func equalJSON(t *testing.T, a, b interface{}) bool {
t.Helper()
if reflect.DeepEqual(a, b) { if reflect.DeepEqual(a, b) {
return true return true
} }
ab, err := json.Marshal(a) ab, err := json.Marshal(a)
if err != nil { require.NoError(t, err)
t.Error(err)
return false
}
bb, err := json.Marshal(b) bb, err := json.Marshal(b)
if err != nil { require.NoError(t, err)
t.Error(err)
return false
}
return bytes.Equal(ab, bb) return bytes.Equal(ab, bb)
} }
@ -177,32 +174,23 @@ func TestClient_Version(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Version() got, err := c.Version()
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Version() = %v, want nil", got)
}
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -227,40 +215,30 @@ func TestClient_Health(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Health() got, err := c.Health()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr) assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Health() = %v, want nil", got)
}
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Root(t *testing.T) { func TestClient_Root(t *testing.T) {
ok := &api.RootResponse{ ok := &api.RootResponse{
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
} }
tests := []struct { tests := []struct {
@ -281,10 +259,7 @@ func TestClient_Root(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
expected := "/root/" + tt.shasum expected := "/root/" + tt.shasum
@ -295,37 +270,31 @@ func TestClient_Root(t *testing.T) {
}) })
got, err := c.Root(tt.shasum) got, err := c.Root(tt.shasum)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Root() = %v, want nil", got)
}
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Sign(t *testing.T) { func TestClient_Sign(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
request := &api.SignRequest{ request := &api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)},
OTT: "the-ott", OTT: "the-ott",
NotBefore: api.NewTimeDuration(time.Now()), NotBefore: api.NewTimeDuration(time.Now()),
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
@ -351,16 +320,13 @@ func TestClient_Sign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
sassert.Fatal(t, ok, "response expected to be error type") require.True(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
@ -376,23 +342,16 @@ func TestClient_Sign(t *testing.T) {
}) })
got, err := c.Sign(tt.request) got, err := c.Sign(tt.request)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr) assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got)
}
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -423,16 +382,13 @@ func TestClient_Revoke(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
sassert.Fatal(t, ok, "response expected to be error type") require.True(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
@ -448,34 +404,27 @@ func TestClient_Revoke(t *testing.T) {
}) })
got, err := c.Revoke(tt.request, nil) got, err := c.Revoke(tt.request, nil)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr) assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got)
}
sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Renew(t *testing.T) { func TestClient_Renew(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -498,49 +447,38 @@ func TestClient_Renew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Renew(nil) got, err := c.Renew(nil)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_RenewWithToken(t *testing.T) { func TestClient_RenewWithToken(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -563,10 +501,7 @@ func TestClient_RenewWithToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" { if req.Header.Get("Authorization") != "Bearer token" {
@ -577,44 +512,36 @@ func TestClient_RenewWithToken(t *testing.T) {
}) })
got, err := c.RenewWithToken("token") got, err := c.RenewWithToken("token")
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Rekey(t *testing.T) { func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
request := &api.RekeyRequest{ request := &api.RekeyRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)},
} }
tests := []struct { tests := []struct {
@ -637,38 +564,27 @@ func TestClient_Rekey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Rekey(tt.request, nil) got, err := c.Rekey(tt.request, nil)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -700,10 +616,7 @@ func TestClient_Provisioners(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != tt.expectedURI { if req.RequestURI != tt.expectedURI {
@ -713,22 +626,16 @@ func TestClient_Provisioners(t *testing.T) {
}) })
got, err := c.Provisioners(tt.args...) got, err := c.Provisioners(tt.args...)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Provisioners() = %v, want nil", got)
}
sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -756,10 +663,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
expected := "/provisioners/" + tt.kid + "/encrypted-key" expected := "/provisioners/" + tt.kid + "/encrypted-key"
@ -770,27 +674,20 @@ func TestClient_ProvisionerKey(t *testing.T) {
}) })
got, err := c.ProvisionerKey(tt.kid) got, err := c.ProvisionerKey(tt.kid)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -798,7 +695,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
func TestClient_Roots(t *testing.T) { func TestClient_Roots(t *testing.T) {
ok := &api.RootsResponse{ ok := &api.RootsResponse{
Certificates: []api.Certificate{ Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -820,37 +717,27 @@ func TestClient_Roots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Roots() got, err := c.Roots()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -858,7 +745,7 @@ func TestClient_Roots(t *testing.T) {
func TestClient_Federation(t *testing.T) { func TestClient_Federation(t *testing.T) {
ok := &api.FederationResponse{ ok := &api.FederationResponse{
Certificates: []api.Certificate{ Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -879,46 +766,34 @@ func TestClient_Federation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Federation() got, err := c.Federation()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_SSHRoots(t *testing.T) { func TestClient_SSHRoots(t *testing.T) {
key, err := ssh.NewPublicKey(mustKey().Public()) key, err := ssh.NewPublicKey(mustKey(t).Public())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ok := &api.SSHRootsResponse{ ok := &api.SSHRootsResponse{
HostKeys: []api.SSHPublicKey{{PublicKey: key}}, HostKeys: []api.SSHPublicKey{{PublicKey: key}},
@ -942,37 +817,27 @@ func TestClient_SSHRoots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHRoots() got, err := c.SSHRoots()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got)
}
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
}
sassert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -1004,13 +869,14 @@ func Test_parseEndpoint(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := parseEndpoint(tt.args.endpoint) got, err := parseEndpoint(tt.args.endpoint)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr) assert.Error(t, err)
assert.Nil(t, got)
return return
} }
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseEndpoint() = %v, want %v", got, tt.want) assert.NoError(t, err)
} assert.Equal(t, tt.want, got)
}) })
} }
} }
@ -1043,24 +909,21 @@ func TestClient_RootFingerprint(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tr := tt.server.Client().Transport tr := tt.server.Client().Transport
c, err := NewClient(tt.server.URL, WithTransport(tr)) c, err := NewClient(tt.server.URL, WithTransport(tr))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.RootFingerprint() got, err := c.RootFingerprint()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) assert.Error(t, err)
t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr) assert.Empty(t, got)
return return
} }
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want) assert.NoError(t, err)
} assert.Equal(t, tt.want, got)
}) })
} }
} }
@ -1069,12 +932,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) {
srv := startCABootstrapServer() srv := startCABootstrapServer()
defer srv.Close() defer srv.Close()
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
sassert.FatalError(t, err) require.NoError(t, err)
fp, err := client.RootFingerprint() fp, err := caClient.RootFingerprint()
sassert.FatalError(t, err) assert.NoError(t, err)
sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
} }
func TestClient_SSHBastion(t *testing.T) { func TestClient_SSHBastion(t *testing.T) {
@ -1104,39 +967,29 @@ func TestClient_SSHBastion(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHBastion(tt.request) got, err := c.SSHBastion(tt.request)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr) if tt.responseCode != 200 {
return var sc render.StatusCodedError
} if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
switch { }
case err != nil: assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
if got != nil {
t.Errorf("Client.SSHBastion() = %v, want nil", got)
}
if tt.responseCode != 200 {
var sc render.StatusCodedError
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
sassert.HasPrefix(t, err.Error(), tt.err.Error())
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response)
} }
assert.Nil(t, got)
return
} }
assert.NoError(t, err)
assert.Equal(t, tt.response, got)
}) })
} }
} }
@ -1155,13 +1008,10 @@ func TestClient_GetCaURL(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return got := c.GetCaURL()
} assert.Equal(t, tt.want, got)
if got := c.GetCaURL(); got != tt.want {
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
}
}) })
} }
} }
@ -1171,7 +1021,7 @@ func Test_enforceRequestID(t *testing.T) {
set.Header.Set("X-Request-Id", "already-set") set.Header.Set("X-Request-Id", "already-set")
inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context"))
new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) newRequestID := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
tests := []struct { tests := []struct {
name string name string
@ -1190,7 +1040,7 @@ func Test_enforceRequestID(t *testing.T) {
}, },
{ {
name: "new", name: "new",
r: new, r: newRequestID,
}, },
} }
for _, tt := range tests { for _, tt := range tests {

@ -130,7 +130,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootCA(t *testing.T) { func TestAddRootCA(t *testing.T) {
cert := parseCertificate(rootPEM) cert := parseCertificate(t, rootPEM)
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -163,7 +163,7 @@ func TestAddRootCA(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddClientCA(t *testing.T) { func TestAddClientCA(t *testing.T) {
cert := parseCertificate(rootPEM) cert := parseCertificate(t, rootPEM)
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -214,7 +214,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
cert := parseCertificate(string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -269,7 +269,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
cert := parseCertificate(string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -329,8 +329,8 @@ func TestAddFederationToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
crt1 := parseCertificate(string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(string(federated)) crt2 := parseCertificate(t, string(federated))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(crt1) pool.AddCert(crt1)
pool.AddCert(crt2) pool.AddCert(crt2)
@ -394,8 +394,8 @@ func TestAddFederationToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
crt1 := parseCertificate(string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(string(federated)) crt2 := parseCertificate(t, string(federated))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(crt1) pool.AddCert(crt1)
pool.AddCert(crt2) pool.AddCert(crt2)
@ -454,7 +454,7 @@ func TestAddRootsToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
cert := parseCertificate(string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -514,8 +514,8 @@ func TestAddFederationToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
crt1 := parseCertificate(string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(string(federated)) crt2 := parseCertificate(t, string(federated))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(crt1) pool.AddCert(crt1)
pool.AddCert(crt2) pool.AddCert(crt2)

@ -401,13 +401,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
func TestCertificate(t *testing.T) { func TestCertificate(t *testing.T) {
cert := parseCertificate(certPEM) cert := parseCertificate(t, certPEM)
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: cert}, ServerPEM: api.Certificate{Certificate: cert},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: cert}, {Certificate: cert},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
tests := []struct { tests := []struct {
@ -434,12 +434,12 @@ func TestCertificate(t *testing.T) {
} }
func TestIntermediateCertificate(t *testing.T) { func TestIntermediateCertificate(t *testing.T) {
intermediate := parseCertificate(rootPEM) intermediate := parseCertificate(t, rootPEM)
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: intermediate}, CaPEM: api.Certificate{Certificate: intermediate},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: intermediate}, {Certificate: intermediate},
}, },
} }
@ -467,24 +467,24 @@ func TestIntermediateCertificate(t *testing.T) {
} }
func TestRootCertificateCertificate(t *testing.T) { func TestRootCertificateCertificate(t *testing.T) {
root := parseCertificate(rootPEM) root := parseCertificate(t, rootPEM)
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
{root, root}, {root, root},
}}, }},
} }
noTLS := &api.SignResponse{ noTLS := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
tests := []struct { tests := []struct {

Loading…
Cancel
Save