|
|
|
@ -9,24 +9,26 @@ import (
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"encoding/pem"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
"net/url"
|
|
|
|
|
"reflect"
|
|
|
|
|
"strings"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"go.step.sm/crypto/x509util"
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/assert"
|
|
|
|
|
"github.com/google/uuid"
|
|
|
|
|
"github.com/smallstep/certificates/api"
|
|
|
|
|
"github.com/smallstep/certificates/api/read"
|
|
|
|
|
"github.com/smallstep/certificates/api/render"
|
|
|
|
|
"github.com/smallstep/certificates/authority"
|
|
|
|
|
"github.com/smallstep/certificates/authority/provisioner"
|
|
|
|
|
"github.com/smallstep/certificates/ca/client"
|
|
|
|
|
"github.com/smallstep/certificates/errs"
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
"go.step.sm/crypto/x509util"
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
@ -106,52 +108,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
|
|
|
|
|
-----END CERTIFICATE REQUEST-----`
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func mustKey() *ecdsa.PrivateKey {
|
|
|
|
|
func mustKey(t *testing.T) *ecdsa.PrivateKey {
|
|
|
|
|
t.Helper()
|
|
|
|
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
return priv
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func parseCertificate(data string) *x509.Certificate {
|
|
|
|
|
func parseCertificate(t *testing.T, data string) *x509.Certificate {
|
|
|
|
|
t.Helper()
|
|
|
|
|
block, _ := pem.Decode([]byte(data))
|
|
|
|
|
if block == nil {
|
|
|
|
|
panic("failed to parse certificate PEM")
|
|
|
|
|
require.Fail(t, "failed to parse certificate PEM")
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
cert, err := x509.ParseCertificate(block.Bytes)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic("failed to parse certificate: " + err.Error())
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err, "failed to parse certificate")
|
|
|
|
|
return cert
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func parseCertificateRequest(string) *x509.CertificateRequest {
|
|
|
|
|
func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest {
|
|
|
|
|
t.Helper()
|
|
|
|
|
block, _ := pem.Decode([]byte(csrPEM))
|
|
|
|
|
if block == nil {
|
|
|
|
|
panic("failed to parse certificate request PEM")
|
|
|
|
|
require.Fail(t, "failed to parse certificate request PEM")
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
csr, err := x509.ParseCertificateRequest(block.Bytes)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic("failed to parse certificate request: " + err.Error())
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err, "failed to parse certificate request")
|
|
|
|
|
return csr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func equalJSON(t *testing.T, a, b interface{}) bool {
|
|
|
|
|
t.Helper()
|
|
|
|
|
if reflect.DeepEqual(a, b) {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ab, err := json.Marshal(a)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Error(err)
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
bb, err := json.Marshal(b)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Error(err)
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
return bytes.Equal(ab, bb)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -176,32 +175,23 @@ func TestClient_Version(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Version()
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
assert.EqualError(t, err, tt.expectedErr.Error())
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Version() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -226,40 +216,30 @@ func TestClient_Health(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Health()
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
assert.EqualError(t, err, tt.expectedErr.Error())
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Health() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_Root(t *testing.T) {
|
|
|
|
|
ok := &api.RootResponse{
|
|
|
|
|
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
@ -280,10 +260,7 @@ func TestClient_Root(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
expected := "/root/" + tt.shasum
|
|
|
|
@ -294,37 +271,31 @@ func TestClient_Root(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Root(tt.shasum)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
assert.EqualError(t, err, tt.expectedErr.Error())
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Root() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_Sign(t *testing.T) {
|
|
|
|
|
ok := &api.SignResponse{
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
CertChainPEM: []api.Certificate{
|
|
|
|
|
{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
request := &api.SignRequest{
|
|
|
|
|
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)},
|
|
|
|
|
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)},
|
|
|
|
|
OTT: "the-ott",
|
|
|
|
|
NotBefore: api.NewTimeDuration(time.Now()),
|
|
|
|
|
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
|
|
|
|
@ -350,16 +321,13 @@ func TestClient_Sign(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
body := new(api.SignRequest)
|
|
|
|
|
if err := read.JSON(req.Body, body); err != nil {
|
|
|
|
|
e, ok := tt.response.(error)
|
|
|
|
|
assert.Fatal(t, ok, "response expected to be error type")
|
|
|
|
|
require.True(t, ok, "response expected to be error type")
|
|
|
|
|
render.Error(w, e)
|
|
|
|
|
return
|
|
|
|
|
} else if !equalJSON(t, body, tt.request) {
|
|
|
|
@ -375,23 +343,16 @@ func TestClient_Sign(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Sign(tt.request)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
assert.EqualError(t, err, tt.expectedErr.Error())
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Sign() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -422,16 +383,13 @@ func TestClient_Revoke(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
body := new(api.RevokeRequest)
|
|
|
|
|
if err := read.JSON(req.Body, body); err != nil {
|
|
|
|
|
e, ok := tt.response.(error)
|
|
|
|
|
assert.Fatal(t, ok, "response expected to be error type")
|
|
|
|
|
require.True(t, ok, "response expected to be error type")
|
|
|
|
|
render.Error(w, e)
|
|
|
|
|
return
|
|
|
|
|
} else if !equalJSON(t, body, tt.request) {
|
|
|
|
@ -447,34 +405,27 @@ func TestClient_Revoke(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Revoke(tt.request, nil)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Revoke() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_Renew(t *testing.T) {
|
|
|
|
|
ok := &api.SignResponse{
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
CertChainPEM: []api.Certificate{
|
|
|
|
|
{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -497,49 +448,38 @@ func TestClient_Renew(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Renew(nil)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Renew() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_RenewWithToken(t *testing.T) {
|
|
|
|
|
ok := &api.SignResponse{
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
CertChainPEM: []api.Certificate{
|
|
|
|
|
{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -562,10 +502,7 @@ func TestClient_RenewWithToken(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
if req.Header.Get("Authorization") != "Bearer token" {
|
|
|
|
@ -576,44 +513,36 @@ func TestClient_RenewWithToken(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.RenewWithToken("token")
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_Rekey(t *testing.T) {
|
|
|
|
|
ok := &api.SignResponse{
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
CertChainPEM: []api.Certificate{
|
|
|
|
|
{Certificate: parseCertificate(certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, certPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
request := &api.RekeyRequest{
|
|
|
|
|
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)},
|
|
|
|
|
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
@ -636,38 +565,27 @@ func TestClient_Rekey(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Rekey(tt.request, nil)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Renew() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -699,10 +617,7 @@ func TestClient_Provisioners(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
if req.RequestURI != tt.expectedURI {
|
|
|
|
@ -712,22 +627,16 @@ func TestClient_Provisioners(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Provisioners(tt.args...)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Provisioners() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -755,10 +664,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
expected := "/provisioners/" + tt.kid + "/encrypted-key"
|
|
|
|
@ -769,27 +675,20 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.ProvisionerKey(tt.kid)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -797,7 +696,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|
|
|
|
func TestClient_Roots(t *testing.T) {
|
|
|
|
|
ok := &api.RootsResponse{
|
|
|
|
|
Certificates: []api.Certificate{
|
|
|
|
|
{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -819,37 +718,27 @@ func TestClient_Roots(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Roots()
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Roots() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -857,7 +746,7 @@ func TestClient_Roots(t *testing.T) {
|
|
|
|
|
func TestClient_Federation(t *testing.T) {
|
|
|
|
|
ok := &api.FederationResponse{
|
|
|
|
|
Certificates: []api.Certificate{
|
|
|
|
|
{Certificate: parseCertificate(rootPEM)},
|
|
|
|
|
{Certificate: parseCertificate(t, rootPEM)},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -878,46 +767,34 @@ func TestClient_Federation(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.Federation()
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Federation() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_SSHRoots(t *testing.T) {
|
|
|
|
|
key, err := ssh.NewPublicKey(mustKey().Public())
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
key, err := ssh.NewPublicKey(mustKey(t).Public())
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
ok := &api.SSHRootsResponse{
|
|
|
|
|
HostKeys: []api.SSHPublicKey{{PublicKey: key}},
|
|
|
|
@ -941,37 +818,27 @@ func TestClient_SSHRoots(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.SSHRoots()
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.SSHKeys() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1003,13 +870,14 @@ func Test_parseEndpoint(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
got, err := parseEndpoint(tt.args.endpoint)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
assert.Error(t, err)
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
|
|
|
t.Errorf("parseEndpoint() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.want, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1042,24 +910,21 @@ func TestClient_RootFingerprint(t *testing.T) {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
tr := tt.server.Client().Transport
|
|
|
|
|
c, err := NewClient(tt.server.URL, WithTransport(tr))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.RootFingerprint()
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
assert.Error(t, err)
|
|
|
|
|
assert.Empty(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
|
|
|
t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.want, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1068,12 +933,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) {
|
|
|
|
|
srv := startCABootstrapServer()
|
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
|
|
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
fp, err := client.RootFingerprint()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
|
|
|
|
|
fp, err := caClient.RootFingerprint()
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_SSHBastion(t *testing.T) {
|
|
|
|
@ -1103,39 +968,29 @@ func TestClient_SSHBastion(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
|
|
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
got, err := c.SSHBastion(tt.request)
|
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
|
|
|
fmt.Printf("%+v", err)
|
|
|
|
|
t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case err != nil:
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.SSHBastion() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
if tt.responseCode != 200 {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
if assert.Error(t, err) {
|
|
|
|
|
if tt.responseCode != 200 {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.ErrorAs(t, err, &sc) {
|
|
|
|
|
assert.Equal(t, tt.responseCode, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response)
|
|
|
|
|
}
|
|
|
|
|
assert.Nil(t, got)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, tt.response, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1154,13 +1009,60 @@ func TestClient_GetCaURL(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("NewClient() error = %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if got := c.GetCaURL(); got != tt.want {
|
|
|
|
|
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
got := c.GetCaURL()
|
|
|
|
|
assert.Equal(t, tt.want, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_enforceRequestID(t *testing.T) {
|
|
|
|
|
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
|
|
|
|
set.Header.Set("X-Request-Id", "already-set")
|
|
|
|
|
inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
|
|
|
|
inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context"))
|
|
|
|
|
newRequestID := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
r *http.Request
|
|
|
|
|
want string
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "set",
|
|
|
|
|
r: set,
|
|
|
|
|
want: "already-set",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "context",
|
|
|
|
|
r: inContext,
|
|
|
|
|
want: "from-context",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "new",
|
|
|
|
|
r: newRequestID,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
enforceRequestID(tt.r)
|
|
|
|
|
|
|
|
|
|
v := tt.r.Header.Get("X-Request-Id")
|
|
|
|
|
if assert.NotEmpty(t, v) {
|
|
|
|
|
if tt.want != "" {
|
|
|
|
|
assert.Equal(t, tt.want, v)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_newRequestID(t *testing.T) {
|
|
|
|
|
requestID := newRequestID()
|
|
|
|
|
u, err := uuid.Parse(requestID)
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
assert.Equal(t, uuid.Version(0x4), u.Version())
|
|
|
|
|
assert.Equal(t, uuid.RFC4122, u.Variant())
|
|
|
|
|
assert.Equal(t, requestID, u.String())
|
|
|
|
|
}
|
|
|
|
|