@ -9,24 +9,26 @@ import (
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/google/uuid"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
)
const (
@ -106,52 +108,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
-- -- - END CERTIFICATE REQUEST -- -- - `
)
func mustKey ( ) * ecdsa . PrivateKey {
func mustKey ( t * testing . T ) * ecdsa . PrivateKey {
t . Helper ( )
priv , err := ecdsa . GenerateKey ( elliptic . P256 ( ) , rand . Reader )
if err != nil {
panic ( err )
}
require . NoError ( t , err )
return priv
}
func parseCertificate ( data string ) * x509 . Certificate {
func parseCertificate ( t * testing . T , data string ) * x509 . Certificate {
t . Helper ( )
block , _ := pem . Decode ( [ ] byte ( data ) )
if block == nil {
panic ( "failed to parse certificate PEM" )
require . Fail ( t , "failed to parse certificate PEM" )
return nil
}
cert , err := x509 . ParseCertificate ( block . Bytes )
if err != nil {
panic ( "failed to parse certificate: " + err . Error ( ) )
}
require . NoError ( t , err , "failed to parse certificate" )
return cert
}
func parseCertificateRequest ( string ) * x509 . CertificateRequest {
func parseCertificateRequest ( t * testing . T , csrPEM string ) * x509 . CertificateRequest {
t . Helper ( )
block , _ := pem . Decode ( [ ] byte ( csrPEM ) )
if block == nil {
panic ( "failed to parse certificate request PEM" )
require . Fail ( t , "failed to parse certificate request PEM" )
return nil
}
csr , err := x509 . ParseCertificateRequest ( block . Bytes )
if err != nil {
panic ( "failed to parse certificate request: " + err . Error ( ) )
}
require . NoError ( t , err , "failed to parse certificate request" )
return csr
}
func equalJSON ( t * testing . T , a , b interface { } ) bool {
t . Helper ( )
if reflect . DeepEqual ( a , b ) {
return true
}
ab , err := json . Marshal ( a )
if err != nil {
t . Error ( err )
return false
}
require . NoError ( t , err )
bb , err := json . Marshal ( b )
if err != nil {
t . Error ( err )
return false
}
require . NoError ( t , err )
return bytes . Equal ( ab , bb )
}
@ -176,32 +175,23 @@ func TestClient_Version(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . Version ( )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "Client.Version() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
assert . EqualError ( t , err , tt . expectedErr . Error ( ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Version() = %v, want nil" , got )
}
assert . HasPrefix ( t , tt . expectedErr . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Version() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -226,40 +216,30 @@ func TestClient_Health(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . Health ( )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Health() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
assert . EqualError ( t , err , tt . expectedErr . Error ( ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Health() = %v, want nil" , got )
}
assert . HasPrefix ( t , tt . expectedErr . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Health() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
func TestClient_Root ( t * testing . T ) {
ok := & api . RootResponse {
RootPEM : api . Certificate { Certificate : parseCertificate ( rootPEM) } ,
RootPEM : api . Certificate { Certificate : parseCertificate ( t , rootPEM ) } ,
}
tests := [ ] struct {
@ -280,10 +260,7 @@ func TestClient_Root(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
expected := "/root/" + tt . shasum
@ -294,37 +271,31 @@ func TestClient_Root(t *testing.T) {
} )
got , err := c . Root ( tt . shasum )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "Client.Root() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
assert . EqualError ( t , err , tt . expectedErr . Error ( ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Root() = %v, want nil" , got )
}
assert . HasPrefix ( t , tt . expectedErr . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Root() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
func TestClient_Sign ( t * testing . T ) {
ok := & api . SignResponse {
ServerPEM : api . Certificate { Certificate : parseCertificate ( certPEM) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( rootPEM) } ,
ServerPEM : api . Certificate { Certificate : parseCertificate ( t , certPEM ) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( t , rootPEM ) } ,
CertChainPEM : [ ] api . Certificate {
{ Certificate : parseCertificate ( certPEM) } ,
{ Certificate : parseCertificate ( rootPEM) } ,
{ Certificate : parseCertificate ( t, certPEM) } ,
{ Certificate : parseCertificate ( t, rootPEM) } ,
} ,
}
request := & api . SignRequest {
CsrPEM : api . CertificateRequest { CertificateRequest : parseCertificateRequest ( csrPEM) } ,
CsrPEM : api . CertificateRequest { CertificateRequest : parseCertificateRequest ( t, csrPEM) } ,
OTT : "the-ott" ,
NotBefore : api . NewTimeDuration ( time . Now ( ) ) ,
NotAfter : api . NewTimeDuration ( time . Now ( ) . AddDate ( 0 , 1 , 0 ) ) ,
@ -350,16 +321,13 @@ func TestClient_Sign(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
body := new ( api . SignRequest )
if err := read . JSON ( req . Body , body ) ; err != nil {
e , ok := tt . response . ( error )
assert. Fatal ( t , ok , "response expected to be error type" )
require. True ( t , ok , "response expected to be error type" )
render . Error ( w , e )
return
} else if ! equalJSON ( t , body , tt . request ) {
@ -375,23 +343,16 @@ func TestClient_Sign(t *testing.T) {
} )
got , err := c . Sign ( tt . request )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Sign() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
assert . EqualError ( t , err , tt . expectedErr . Error ( ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Sign() = %v, want nil" , got )
}
assert . HasPrefix ( t , tt . expectedErr . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Sign() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -422,16 +383,13 @@ func TestClient_Revoke(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
body := new ( api . RevokeRequest )
if err := read . JSON ( req . Body , body ) ; err != nil {
e , ok := tt . response . ( error )
assert. Fatal ( t , ok , "response expected to be error type" )
require. True ( t , ok , "response expected to be error type" )
render . Error ( w , e )
return
} else if ! equalJSON ( t , body , tt . request ) {
@ -447,34 +405,27 @@ func TestClient_Revoke(t *testing.T) {
} )
got , err := c . Revoke ( tt . request , nil )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Revoke() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . expectedErr . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Revoke() = %v, want nil" , got )
}
assert . HasPrefix ( t , err . Error ( ) , tt . expectedErr . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Revoke() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
func TestClient_Renew ( t * testing . T ) {
ok := & api . SignResponse {
ServerPEM : api . Certificate { Certificate : parseCertificate ( certPEM) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( rootPEM) } ,
ServerPEM : api . Certificate { Certificate : parseCertificate ( t , certPEM ) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( t , rootPEM ) } ,
CertChainPEM : [ ] api . Certificate {
{ Certificate : parseCertificate ( certPEM) } ,
{ Certificate : parseCertificate ( rootPEM) } ,
{ Certificate : parseCertificate ( t, certPEM) } ,
{ Certificate : parseCertificate ( t, rootPEM) } ,
} ,
}
@ -497,49 +448,38 @@ func TestClient_Renew(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . Renew ( nil )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Renew() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Renew() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , err . Error ( ) , tt . err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Renew() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
func TestClient_RenewWithToken ( t * testing . T ) {
ok := & api . SignResponse {
ServerPEM : api . Certificate { Certificate : parseCertificate ( certPEM) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( rootPEM) } ,
ServerPEM : api . Certificate { Certificate : parseCertificate ( t , certPEM ) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( t , rootPEM ) } ,
CertChainPEM : [ ] api . Certificate {
{ Certificate : parseCertificate ( certPEM) } ,
{ Certificate : parseCertificate ( rootPEM) } ,
{ Certificate : parseCertificate ( t , certPEM ) } ,
{ Certificate : parseCertificate ( t , rootPEM ) } ,
} ,
}
@ -562,10 +502,7 @@ func TestClient_RenewWithToken(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
if req . Header . Get ( "Authorization" ) != "Bearer token" {
@ -576,44 +513,36 @@ func TestClient_RenewWithToken(t *testing.T) {
} )
got , err := c . RenewWithToken ( "token" )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.RenewWithToken() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.RenewWithToken() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , err . Error ( ) , tt . err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.RenewWithToken() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
func TestClient_Rekey ( t * testing . T ) {
ok := & api . SignResponse {
ServerPEM : api . Certificate { Certificate : parseCertificate ( certPEM) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( rootPEM) } ,
ServerPEM : api . Certificate { Certificate : parseCertificate ( t , certPEM ) } ,
CaPEM : api . Certificate { Certificate : parseCertificate ( t , rootPEM ) } ,
CertChainPEM : [ ] api . Certificate {
{ Certificate : parseCertificate ( certPEM) } ,
{ Certificate : parseCertificate ( rootPEM) } ,
{ Certificate : parseCertificate ( t , certPEM ) } ,
{ Certificate : parseCertificate ( t , rootPEM ) } ,
} ,
}
request := & api . RekeyRequest {
CsrPEM : api . CertificateRequest { CertificateRequest : parseCertificateRequest ( csrPEM) } ,
CsrPEM : api . CertificateRequest { CertificateRequest : parseCertificateRequest ( t, csrPEM) } ,
}
tests := [ ] struct {
@ -636,38 +565,27 @@ func TestClient_Rekey(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . Rekey ( tt . request , nil )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Renew() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Renew() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , err . Error ( ) , tt . err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Renew() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -699,10 +617,7 @@ func TestClient_Provisioners(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
if req . RequestURI != tt . expectedURI {
@ -712,22 +627,16 @@ func TestClient_Provisioners(t *testing.T) {
} )
got , err := c . Provisioners ( tt . args ... )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "Client.Provisioners() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
assert . True ( t , strings . HasPrefix ( err . Error ( ) , errs . InternalServerErrorDefaultMsg ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Provisioners() = %v, want nil" , got )
}
assert . HasPrefix ( t , errs . InternalServerErrorDefaultMsg , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Provisioners() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -755,10 +664,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
expected := "/provisioners/" + tt . kid + "/encrypted-key"
@ -769,27 +675,20 @@ func TestClient_ProvisionerKey(t *testing.T) {
} )
got , err := c . ProvisionerKey ( tt . kid )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "Client.ProvisionerKey() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.ProvisionerKey() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , tt . err . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.ProvisionerKey() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -797,7 +696,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
func TestClient_Roots ( t * testing . T ) {
ok := & api . RootsResponse {
Certificates : [ ] api . Certificate {
{ Certificate : parseCertificate ( rootPEM) } ,
{ Certificate : parseCertificate ( t, rootPEM) } ,
} ,
}
@ -819,37 +718,27 @@ func TestClient_Roots(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . Roots ( )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Roots() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Roots() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , err . Error ( ) , tt . err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Roots() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -857,7 +746,7 @@ func TestClient_Roots(t *testing.T) {
func TestClient_Federation ( t * testing . T ) {
ok := & api . FederationResponse {
Certificates : [ ] api . Certificate {
{ Certificate : parseCertificate ( rootPEM) } ,
{ Certificate : parseCertificate ( t, rootPEM) } ,
} ,
}
@ -878,46 +767,34 @@ func TestClient_Federation(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . Federation ( )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.Federation() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.Federation() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , tt . err . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.Federation() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
func TestClient_SSHRoots ( t * testing . T ) {
key , err := ssh . NewPublicKey ( mustKey ( ) . Public ( ) )
if err != nil {
t . Fatal ( err )
}
key , err := ssh . NewPublicKey ( mustKey ( t ) . Public ( ) )
require . NoError ( t , err )
ok := & api . SSHRootsResponse {
HostKeys : [ ] api . SSHPublicKey { { PublicKey : key } } ,
@ -941,37 +818,27 @@ func TestClient_SSHRoots(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . SSHRoots ( )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.SSHKeys() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
if assert . Error ( t , err ) {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . Nil ( t , got )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.SSHKeys() = %v, want nil" , got )
}
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
}
assert . HasPrefix ( t , tt . err . Error ( ) , err . Error ( ) )
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.SSHKeys() = %v, want %v" , got , tt . response )
}
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -1003,13 +870,14 @@ func Test_parseEndpoint(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got , err := parseEndpoint ( tt . args . endpoint )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "parseEndpoint() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
assert . Error ( t , err )
assert . Nil ( t , got )
return
}
if ! reflect . DeepEqual ( got , tt . want ) {
t . Errorf ( "parseEndpoint() = %v, want %v" , got , tt . want )
}
assert . NoError ( t , err )
assert . Equal ( t , tt . want , got )
} )
}
}
@ -1042,24 +910,21 @@ func TestClient_RootFingerprint(t *testing.T) {
t . Run ( tt . name , func ( t * testing . T ) {
tr := tt . server . Client ( ) . Transport
c , err := NewClient ( tt . server . URL , WithTransport ( tr ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
tt . server . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . RootFingerprint ( )
if ( err != nil ) != tt . wantErr {
fmt. Printf ( "%+v" , err )
t. Errorf ( "Client.RootFingerprint() error = %v, wantErr %v" , err , tt . wantErr )
if tt . wantErr {
assert. Error ( t , err )
assert. Empty ( t , got )
return
}
if ! reflect . DeepEqual ( got , tt . want ) {
t . Errorf ( "Client.RootFingerprint() = %v, want %v" , got , tt . want )
}
assert . NoError ( t , err )
assert . Equal ( t , tt . want , got )
} )
}
}
@ -1068,12 +933,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) {
srv := startCABootstrapServer ( )
defer srv . Close ( )
c lient, err := NewClient ( srv . URL + "/sign" , WithRootFile ( "testdata/secrets/root_ca.crt" ) )
assert. Fatal Error( t , err )
c aC lient, err := NewClient ( srv . URL + "/sign" , WithRootFile ( "testdata/secrets/root_ca.crt" ) )
require. No Error( t , err )
fp , err := c lient. RootFingerprint ( )
assert . Fatal Error( t , err )
assert . Equal s ( t , "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" , fp )
fp , err := c aC lient. RootFingerprint ( )
assert . No Error( t , err )
assert . Equal ( t , "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" , fp )
}
func TestClient_SSHBastion ( t * testing . T ) {
@ -1103,39 +968,29 @@ func TestClient_SSHBastion(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( srv . URL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
require . NoError ( t , err )
srv . Config . Handler = http . HandlerFunc ( func ( w http . ResponseWriter , req * http . Request ) {
render . JSONStatus ( w , tt . response , tt . responseCode )
} )
got , err := c . SSHBastion ( tt . request )
if ( err != nil ) != tt . wantErr {
fmt . Printf ( "%+v" , err )
t . Errorf ( "Client.SSHBastion() error = %v, wantErr %v" , err , tt . wantErr )
return
}
switch {
case err != nil :
if got != nil {
t . Errorf ( "Client.SSHBastion() = %v, want nil" , got )
}
if tt . responseCode != 200 {
var sc render . StatusCodedError
if assert . True ( t , errors . As ( err , & sc ) , "error does not implement StatusCodedError interface" ) {
assert . Equals ( t , sc . StatusCode ( ) , tt . responseCode )
if tt . wantErr {
if assert . Error ( t , err ) {
if tt . responseCode != 200 {
var sc render . StatusCodedError
if assert . ErrorAs ( t , err , & sc ) {
assert . Equal ( t , tt . responseCode , sc . StatusCode ( ) )
}
assert . True ( t , strings . HasPrefix ( err . Error ( ) , tt . err . Error ( ) ) )
}
assert . HasPrefix ( t , err . Error ( ) , tt . err . Error ( ) )
}
default :
if ! reflect . DeepEqual ( got , tt . response ) {
t . Errorf ( "Client.SSHBastion() = %v, want %v" , got , tt . response )
}
assert . Nil ( t , got )
return
}
assert . NoError ( t , err )
assert . Equal ( t , tt . response , got )
} )
}
}
@ -1154,13 +1009,60 @@ func TestClient_GetCaURL(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c , err := NewClient ( tt . caURL , WithTransport ( http . DefaultTransport ) )
if err != nil {
t . Errorf ( "NewClient() error = %v" , err )
return
}
if got := c . GetCaURL ( ) ; got != tt . want {
t . Errorf ( "Client.GetCaURL() = %v, want %v" , got , tt . want )
require . NoError ( t , err )
got := c . GetCaURL ( )
assert . Equal ( t , tt . want , got )
} )
}
}
func Test_enforceRequestID ( t * testing . T ) {
set := httptest . NewRequest ( http . MethodGet , "https://example.com" , http . NoBody )
set . Header . Set ( "X-Request-Id" , "already-set" )
inContext := httptest . NewRequest ( http . MethodGet , "https://example.com" , http . NoBody )
inContext = inContext . WithContext ( client . NewRequestIDContext ( inContext . Context ( ) , "from-context" ) )
newRequestID := httptest . NewRequest ( http . MethodGet , "https://example.com" , http . NoBody )
tests := [ ] struct {
name string
r * http . Request
want string
} {
{
name : "set" ,
r : set ,
want : "already-set" ,
} ,
{
name : "context" ,
r : inContext ,
want : "from-context" ,
} ,
{
name : "new" ,
r : newRequestID ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
enforceRequestID ( tt . r )
v := tt . r . Header . Get ( "X-Request-Id" )
if assert . NotEmpty ( t , v ) {
if tt . want != "" {
assert . Equal ( t , tt . want , v )
}
}
} )
}
}
func Test_newRequestID ( t * testing . T ) {
requestID := newRequestID ( )
u , err := uuid . Parse ( requestID )
assert . NoError ( t , err )
assert . Equal ( t , uuid . Version ( 0x4 ) , u . Version ( ) )
assert . Equal ( t , uuid . RFC4122 , u . Variant ( ) )
assert . Equal ( t , requestID , u . String ( ) )
}