2019-12-12 04:23:44 +00:00
package identity
import (
2019-12-12 20:23:53 +00:00
"crypto"
2019-12-12 04:23:44 +00:00
"crypto/tls"
2019-12-18 20:46:46 +00:00
"crypto/x509"
"fmt"
2019-12-12 20:23:53 +00:00
"io/ioutil"
2019-12-18 20:46:46 +00:00
"net/http"
2019-12-12 20:23:53 +00:00
"os"
"path/filepath"
2019-12-12 04:23:44 +00:00
"reflect"
"testing"
2019-12-12 20:23:53 +00:00
"github.com/smallstep/certificates/api"
2020-08-14 22:33:50 +00:00
"go.step.sm/crypto/pemutil"
2019-12-12 04:23:44 +00:00
)
func TestLoadDefaultIdentity ( t * testing . T ) {
oldFile := IdentityFile
defer func ( ) {
IdentityFile = oldFile
} ( )
expected := & Identity {
Type : "mTLS" ,
Certificate : "testdata/identity/identity.crt" ,
Key : "testdata/identity/identity_key" ,
}
tests := [ ] struct {
name string
prepare func ( )
want * Identity
wantErr bool
} {
2021-11-15 23:32:07 +00:00
{ "ok" , func ( ) { IdentityFile = returnInput ( "testdata/config/identity.json" ) } , expected , false } ,
{ "fail read" , func ( ) { IdentityFile = returnInput ( "testdata/config/missing.json" ) } , nil , true } ,
{ "fail unmarshal" , func ( ) { IdentityFile = returnInput ( "testdata/config/fail.json" ) } , nil , true } ,
2019-12-12 04:23:44 +00:00
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
tt . prepare ( )
got , err := LoadDefaultIdentity ( )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "LoadDefaultIdentity() error = %v, wantErr %v" , err , tt . wantErr )
return
}
if ! reflect . DeepEqual ( got , tt . want ) {
t . Errorf ( "LoadDefaultIdentity() = %v, want %v" , got , tt . want )
}
} )
}
}
func TestIdentity_Kind ( t * testing . T ) {
type fields struct {
Type string
}
tests := [ ] struct {
name string
fields fields
want Type
} {
{ "disabled" , fields { "" } , Disabled } ,
{ "mutualTLS" , fields { "mTLS" } , MutualTLS } ,
2021-04-22 01:07:59 +00:00
{ "tunnelTLS" , fields { "tTLS" } , TunnelTLS } ,
2019-12-12 04:23:44 +00:00
{ "unknown" , fields { "unknown" } , Type ( "unknown" ) } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
i := & Identity {
Type : tt . fields . Type ,
}
if got := i . Kind ( ) ; got != tt . want {
t . Errorf ( "Identity.Kind() = %v, want %v" , got , tt . want )
}
} )
}
}
func TestIdentity_Validate ( t * testing . T ) {
type fields struct {
Type string
Certificate string
Key string
2021-04-22 01:07:59 +00:00
Host string
Root string
2019-12-12 04:23:44 +00:00
}
tests := [ ] struct {
name string
fields fields
wantErr bool
} {
2021-04-22 01:07:59 +00:00
{ "ok mTLS" , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" , "" , "" } , false } ,
{ "ok tTLS" , fields { "tTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" , "tunnel:443" , "testdata/certs/root_ca.crt" } , false } ,
2019-12-12 04:23:44 +00:00
{ "ok disabled" , fields { } , false } ,
2021-04-22 01:07:59 +00:00
{ "fail type" , fields { "foo" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" , "" , "" } , true } ,
{ "fail certificate" , fields { "mTLS" , "" , "testdata/identity/identity_key" , "" , "" } , true } ,
{ "fail key" , fields { "mTLS" , "testdata/identity/identity.crt" , "" , "" , "" } , true } ,
{ "fail key" , fields { "tTLS" , "testdata/identity/identity.crt" , "" , "tunnel:443" , "testdata/certs/root_ca.crt" } , true } ,
{ "fail missing certificate" , fields { "mTLS" , "testdata/identity/missing.crt" , "testdata/identity/identity_key" , "" , "" } , true } ,
{ "fail missing certificate" , fields { "tTLS" , "testdata/identity/missing.crt" , "testdata/identity/identity_key" , "tunnel:443" , "testdata/certs/root_ca.crt" } , true } ,
{ "fail missing key" , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/missing_key" , "" , "" } , true } ,
{ "fail missing key" , fields { "tTLS" , "testdata/identity/identity.crt" , "testdata/identity/missing_key" , "tunnel:443" , "testdata/certs/root_ca.crt" } , true } ,
{ "fail host" , fields { "tTLS" , "testdata/identity/identity.crt" , "testdata/identity/missing_key" , "" , "testdata/certs/root_ca.crt" } , true } ,
{ "fail root" , fields { "tTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" , "tunnel:443" , "testdata/certs/missing.crt" } , true } ,
2019-12-12 04:23:44 +00:00
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
i := & Identity {
Type : tt . fields . Type ,
Certificate : tt . fields . Certificate ,
Key : tt . fields . Key ,
2021-04-22 01:07:59 +00:00
Host : tt . fields . Host ,
Root : tt . fields . Root ,
2019-12-12 04:23:44 +00:00
}
if err := i . Validate ( ) ; ( err != nil ) != tt . wantErr {
t . Errorf ( "Identity.Validate() error = %v, wantErr %v" , err , tt . wantErr )
}
} )
}
}
func TestIdentity_TLSCertificate ( t * testing . T ) {
expected , err := tls . LoadX509KeyPair ( "testdata/identity/identity.crt" , "testdata/identity/identity_key" )
if err != nil {
t . Fatal ( err )
}
type fields struct {
Type string
Certificate string
Key string
}
tests := [ ] struct {
name string
fields fields
want tls . Certificate
wantErr bool
} {
2021-04-22 01:07:59 +00:00
{ "ok mTLS" , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , expected , false } ,
{ "ok tTLS" , fields { "tTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , expected , false } ,
2019-12-12 04:23:44 +00:00
{ "ok disabled" , fields { } , tls . Certificate { } , false } ,
{ "fail type" , fields { "foo" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , tls . Certificate { } , true } ,
{ "fail certificate" , fields { "mTLS" , "testdata/certs/server.crt" , "testdata/identity/identity_key" } , tls . Certificate { } , true } ,
{ "fail not after" , fields { "mTLS" , "testdata/identity/expired.crt" , "testdata/identity/identity_key" } , tls . Certificate { } , true } ,
{ "fail not before" , fields { "mTLS" , "testdata/identity/not_before.crt" , "testdata/identity/identity_key" } , tls . Certificate { } , true } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
i := & Identity {
Type : tt . fields . Type ,
Certificate : tt . fields . Certificate ,
Key : tt . fields . Key ,
}
got , err := i . TLSCertificate ( )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "Identity.TLSCertificate() error = %v, wantErr %v" , err , tt . wantErr )
return
}
if ! reflect . DeepEqual ( got , tt . want ) {
t . Errorf ( "Identity.TLSCertificate() = %v, want %v" , got , tt . want )
}
} )
}
}
// Test_fileExists checks that fileExists succeeds for an existing regular file
// and returns an error for a path that does not exist or that names a
// directory.
func Test_fileExists(t *testing.T) {
	cases := []struct {
		name     string
		filename string
		wantErr  bool
	}{
		{name: "ok", filename: "testdata/identity/identity.crt"},
		{name: "missing", filename: "testdata/identity/missing.crt", wantErr: true},
		{name: "directory", filename: "testdata/identity", wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := fileExists(tc.filename)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("fileExists() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}
2019-12-12 20:23:53 +00:00
func TestWriteDefaultIdentity ( t * testing . T ) {
tmpDir , err := ioutil . TempDir ( os . TempDir ( ) , "go-tests" )
if err != nil {
t . Fatal ( err )
}
oldConfigDir := configDir
oldIdentityDir := identityDir
oldIdentityFile := IdentityFile
defer func ( ) {
configDir = oldConfigDir
identityDir = oldIdentityDir
IdentityFile = oldIdentityFile
os . RemoveAll ( tmpDir )
} ( )
certs , err := pemutil . ReadCertificateBundle ( "testdata/identity/identity.crt" )
if err != nil {
t . Fatal ( err )
}
key , err := pemutil . Read ( "testdata/identity/identity_key" )
if err != nil {
t . Fatal ( err )
}
var certChain [ ] api . Certificate
for _ , c := range certs {
certChain = append ( certChain , api . Certificate { Certificate : c } )
}
2021-11-15 23:32:07 +00:00
configDir = returnInput ( filepath . Join ( tmpDir , "config" ) )
identityDir = returnInput ( filepath . Join ( tmpDir , "identity" ) )
IdentityFile = returnInput ( filepath . Join ( tmpDir , "config" , "identity.json" ) )
2019-12-12 20:23:53 +00:00
type args struct {
certChain [ ] api . Certificate
key crypto . PrivateKey
}
tests := [ ] struct {
name string
prepare func ( )
args args
wantErr bool
} {
{ "ok" , func ( ) { } , args { certChain , key } , false } ,
{ "fail mkdir config" , func ( ) {
2021-11-15 23:32:07 +00:00
configDir = returnInput ( filepath . Join ( tmpDir , "identity" , "identity.crt" ) )
identityDir = returnInput ( filepath . Join ( tmpDir , "identity" ) )
2019-12-12 20:23:53 +00:00
} , args { certChain , key } , true } ,
{ "fail mkdir identity" , func ( ) {
2021-11-15 23:32:07 +00:00
configDir = returnInput ( filepath . Join ( tmpDir , "config" ) )
identityDir = returnInput ( filepath . Join ( tmpDir , "identity" , "identity.crt" ) )
2019-12-12 20:23:53 +00:00
} , args { certChain , key } , true } ,
{ "fail certificate" , func ( ) {
2021-11-15 23:32:07 +00:00
configDir = returnInput ( filepath . Join ( tmpDir , "config" ) )
identityDir = returnInput ( filepath . Join ( tmpDir , "bad-dir" ) )
os . MkdirAll ( identityDir ( ) , 0600 )
2019-12-12 20:23:53 +00:00
} , args { certChain , key } , true } ,
{ "fail key" , func ( ) {
2021-11-15 23:32:07 +00:00
configDir = returnInput ( filepath . Join ( tmpDir , "config" ) )
identityDir = returnInput ( filepath . Join ( tmpDir , "identity" ) )
2019-12-12 20:23:53 +00:00
} , args { certChain , "badKey" } , true } ,
{ "fail write identity" , func ( ) {
2021-11-15 23:32:07 +00:00
configDir = returnInput ( filepath . Join ( tmpDir , "bad-dir" ) )
identityDir = returnInput ( filepath . Join ( tmpDir , "identity" ) )
IdentityFile = returnInput ( filepath . Join ( configDir ( ) , "identity.json" ) )
os . MkdirAll ( configDir ( ) , 0600 )
2019-12-12 20:23:53 +00:00
} , args { certChain , key } , true } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
tt . prepare ( )
if err := WriteDefaultIdentity ( tt . args . certChain , tt . args . key ) ; ( err != nil ) != tt . wantErr {
t . Errorf ( "WriteDefaultIdentity() error = %v, wantErr %v" , err , tt . wantErr )
}
} )
}
}
2019-12-18 20:46:46 +00:00
2021-04-22 01:07:59 +00:00
func TestIdentity_GetClientCertificateFunc ( t * testing . T ) {
expected , err := tls . LoadX509KeyPair ( "testdata/identity/identity.crt" , "testdata/identity/identity_key" )
if err != nil {
t . Fatal ( err )
}
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := [ ] struct {
name string
fields fields
want * tls . Certificate
wantErr bool
} {
{ "ok mTLS" , fields { "mtls" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" , "" , "" } , & expected , false } ,
{ "ok tTLS" , fields { "ttls" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" , "tunnel:443" , "testdata/certs/root_ca.crt" } , & expected , false } ,
{ "fail missing cert" , fields { "mTLS" , "testdata/identity/missing.crt" , "testdata/identity/identity_key" , "" , "" } , nil , true } ,
{ "fail missing key" , fields { "tTLS" , "testdata/identity/identity.crt" , "testdata/identity/missing_key" , "tunnel:443" , "testdata/certs/root_ca.crt" } , nil , true } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
i := & Identity {
Type : tt . fields . Type ,
Certificate : tt . fields . Certificate ,
Key : tt . fields . Key ,
Host : tt . fields . Host ,
Root : tt . fields . Root ,
}
fn := i . GetClientCertificateFunc ( )
got , err := fn ( & tls . CertificateRequestInfo { } )
if ( err != nil ) != tt . wantErr {
t . Errorf ( "Identity.GetClientCertificateFunc() = %v, wantErr %v" , err , tt . wantErr )
}
if ! reflect . DeepEqual ( got , tt . want ) {
t . Errorf ( "Identity.GetClientCertificateFunc() = %v, want %v" , got , tt . want )
}
} )
}
}
// TestIdentity_GetCertPool checks Identity.GetCertPool.
//   - A configured Root yields a pool whose subjects match the fixture root CA.
//   - An empty Root yields a nil pool.
//   - A missing root file fails.
//   - A file holding no certificate (here a private key file) fails.
func TestIdentity_GetCertPool(t *testing.T) {
	type fields struct {
		Type        string
		Certificate string
		Key         string
		Host        string
		Root        string
	}
	tests := []struct {
		name         string
		fields       fields
		wantSubjects [][]byte
		wantErr      bool
	}{
		{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
		{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
		{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
		{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			i := &Identity{
				Type:        tt.fields.Type,
				Certificate: tt.fields.Certificate,
				Key:         tt.fields.Key,
				Host:        tt.fields.Host,
				Root:        tt.fields.Root,
			}
			got, err := i.GetCertPool()
			if (err != nil) != tt.wantErr {
				t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// A nil pool is only acceptable when no subjects are expected.
			// Previously a nil result silently skipped the subject comparison,
			// so a case expecting a root would pass on a nil pool.
			if got == nil {
				if tt.wantSubjects != nil {
					t.Errorf("Identity.GetCertPool() = nil, want subjects %x", tt.wantSubjects)
				}
				return
			}
			if subjects := got.Subjects(); !reflect.DeepEqual(subjects, tt.wantSubjects) {
				t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
			}
		})
	}
}
2019-12-18 20:46:46 +00:00
type renewer struct {
pool * x509 . CertPool
sign * api . SignResponse
err error
}
func ( r * renewer ) GetRootCAs ( ) * x509 . CertPool {
return r . pool
}
func ( r * renewer ) Renew ( tr http . RoundTripper ) ( * api . SignResponse , error ) {
return r . sign , r . err
}
func TestIdentity_Renew ( t * testing . T ) {
tmpDir , err := ioutil . TempDir ( os . TempDir ( ) , "go-tests" )
if err != nil {
t . Fatal ( err )
}
oldIdentityDir := identityDir
2021-11-15 23:32:07 +00:00
identityDir = returnInput ( "testdata/identity" )
2019-12-18 20:46:46 +00:00
defer func ( ) {
identityDir = oldIdentityDir
os . RemoveAll ( tmpDir )
} ( )
certs , err := pemutil . ReadCertificateBundle ( "testdata/identity/identity.crt" )
if err != nil {
t . Fatal ( err )
}
ok := & renewer {
sign : & api . SignResponse {
ServerPEM : api . Certificate { Certificate : certs [ 0 ] } ,
CaPEM : api . Certificate { Certificate : certs [ 1 ] } ,
CertChainPEM : [ ] api . Certificate {
{ Certificate : certs [ 0 ] } ,
{ Certificate : certs [ 1 ] } ,
} ,
} ,
}
okOld := & renewer {
sign : & api . SignResponse {
ServerPEM : api . Certificate { Certificate : certs [ 0 ] } ,
CaPEM : api . Certificate { Certificate : certs [ 1 ] } ,
} ,
}
fail := & renewer {
err : fmt . Errorf ( "an error" ) ,
}
type fields struct {
Type string
Certificate string
Key string
}
type args struct {
client Renewer
}
tests := [ ] struct {
name string
prepare func ( )
fields fields
args args
wantErr bool
} {
{ "ok" , func ( ) { } , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , args { ok } , false } ,
{ "ok old" , func ( ) { } , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , args { okOld } , false } ,
{ "ok disabled" , func ( ) { } , fields { } , args { nil } , false } ,
{ "fail type" , func ( ) { } , fields { "foo" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , args { ok } , true } ,
{ "fail renew" , func ( ) { } , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , args { fail } , true } ,
{ "fail certificate" , func ( ) { } , fields { "mTLS" , "testdata/certs/server.crt" , "testdata/identity/identity_key" } , args { ok } , true } ,
{ "fail write identity" , func ( ) {
2021-11-15 23:32:07 +00:00
identityDir = returnInput ( filepath . Join ( tmpDir , "bad-dir" ) )
os . MkdirAll ( identityDir ( ) , 0600 )
2019-12-18 20:46:46 +00:00
} , fields { "mTLS" , "testdata/identity/identity.crt" , "testdata/identity/identity_key" } , args { ok } , true } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
tt . prepare ( )
i := & Identity {
Type : tt . fields . Type ,
Certificate : tt . fields . Certificate ,
Key : tt . fields . Key ,
}
if err := i . Renew ( tt . args . client ) ; ( err != nil ) != tt . wantErr {
t . Errorf ( "Identity.Renew() error = %v, wantErr %v" , err , tt . wantErr )
}
} )
}
}