|
|
|
@ -17,16 +17,17 @@ import (
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"go.step.sm/crypto/x509util"
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/assert"
|
|
|
|
|
sassert "github.com/smallstep/assert"
|
|
|
|
|
"github.com/smallstep/certificates/api"
|
|
|
|
|
"github.com/smallstep/certificates/api/read"
|
|
|
|
|
"github.com/smallstep/certificates/api/render"
|
|
|
|
|
"github.com/smallstep/certificates/authority"
|
|
|
|
|
"github.com/smallstep/certificates/authority/provisioner"
|
|
|
|
|
"github.com/smallstep/certificates/ca/client"
|
|
|
|
|
"github.com/smallstep/certificates/errs"
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"go.step.sm/crypto/x509util"
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
@ -196,7 +197,7 @@ func TestClient_Version(t *testing.T) {
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Version() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
|
|
|
|
@ -247,7 +248,7 @@ func TestClient_Health(t *testing.T) {
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Health() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
|
|
|
|
@ -304,7 +305,7 @@ func TestClient_Root(t *testing.T) {
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Root() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
|
|
|
|
@ -359,7 +360,7 @@ func TestClient_Sign(t *testing.T) {
|
|
|
|
|
body := new(api.SignRequest)
|
|
|
|
|
if err := read.JSON(req.Body, body); err != nil {
|
|
|
|
|
e, ok := tt.response.(error)
|
|
|
|
|
assert.Fatal(t, ok, "response expected to be error type")
|
|
|
|
|
sassert.Fatal(t, ok, "response expected to be error type")
|
|
|
|
|
render.Error(w, e)
|
|
|
|
|
return
|
|
|
|
|
} else if !equalJSON(t, body, tt.request) {
|
|
|
|
@ -386,7 +387,7 @@ func TestClient_Sign(t *testing.T) {
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Sign() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
|
|
|
|
@ -431,7 +432,7 @@ func TestClient_Revoke(t *testing.T) {
|
|
|
|
|
body := new(api.RevokeRequest)
|
|
|
|
|
if err := read.JSON(req.Body, body); err != nil {
|
|
|
|
|
e, ok := tt.response.(error)
|
|
|
|
|
assert.Fatal(t, ok, "response expected to be error type")
|
|
|
|
|
sassert.Fatal(t, ok, "response expected to be error type")
|
|
|
|
|
render.Error(w, e)
|
|
|
|
|
return
|
|
|
|
|
} else if !equalJSON(t, body, tt.request) {
|
|
|
|
@ -458,7 +459,7 @@ func TestClient_Revoke(t *testing.T) {
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Revoke() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
|
|
|
|
|
sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
|
|
|
|
@ -520,10 +521,10 @@ func TestClient_Renew(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
sassert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
|
|
|
|
@ -589,10 +590,10 @@ func TestClient_RenewWithToken(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
sassert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
|
|
|
|
@ -659,10 +660,10 @@ func TestClient_Rekey(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
sassert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
|
|
|
|
@ -722,7 +723,7 @@ func TestClient_Provisioners(t *testing.T) {
|
|
|
|
|
if got != nil {
|
|
|
|
|
t.Errorf("Client.Provisioners() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
|
|
|
|
|
sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
|
|
|
|
@ -781,10 +782,10 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
|
|
|
|
@ -841,10 +842,10 @@ func TestClient_Roots(t *testing.T) {
|
|
|
|
|
t.Errorf("Client.Roots() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
sassert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
|
|
|
|
@ -900,10 +901,10 @@ func TestClient_Federation(t *testing.T) {
|
|
|
|
|
t.Errorf("Client.Federation() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
|
|
|
|
@ -963,10 +964,10 @@ func TestClient_SSHRoots(t *testing.T) {
|
|
|
|
|
t.Errorf("Client.SSHKeys() = %v, want nil", got)
|
|
|
|
|
}
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
sassert.HasPrefix(t, tt.err.Error(), err.Error())
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
|
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
|
|
|
|
@ -1069,11 +1070,11 @@ func TestClient_RootFingerprintWithServer(t *testing.T) {
|
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
|
|
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
sassert.FatalError(t, err)
|
|
|
|
|
|
|
|
|
|
fp, err := client.RootFingerprint()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
|
|
|
|
|
sassert.FatalError(t, err)
|
|
|
|
|
sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestClient_SSHBastion(t *testing.T) {
|
|
|
|
@ -1126,10 +1127,10 @@ func TestClient_SSHBastion(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
if tt.responseCode != 200 {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
|
|
|
|
|
sassert.Equals(t, sc.StatusCode(), tt.responseCode)
|
|
|
|
|
}
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
sassert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
if !reflect.DeepEqual(got, tt.response) {
|
|
|
|
@ -1164,3 +1165,44 @@ func TestClient_GetCaURL(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_enforceRequestID(t *testing.T) {
|
|
|
|
|
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
|
|
|
|
set.Header.Set("X-Request-Id", "already-set")
|
|
|
|
|
inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
|
|
|
|
inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context"))
|
|
|
|
|
new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
r *http.Request
|
|
|
|
|
want string
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "set",
|
|
|
|
|
r: set,
|
|
|
|
|
want: "already-set",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "context",
|
|
|
|
|
r: inContext,
|
|
|
|
|
want: "from-context",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "new",
|
|
|
|
|
r: new,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
enforceRequestID(tt.r)
|
|
|
|
|
|
|
|
|
|
v := tt.r.Header.Get("X-Request-Id")
|
|
|
|
|
if assert.NotEmpty(t, v) {
|
|
|
|
|
if tt.want != "" {
|
|
|
|
|
assert.Equal(t, tt.want, v)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|