package api import ( "bytes" "context" "crypto" "crypto/dsa" //nolint "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/json" "encoding/pem" "fmt" "io" "math/big" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "github.com/go-chi/chi" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) const ( rootPEM = `-----BEGIN CERTIFICATE----- MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY /iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza 
HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx -----END CERTIFICATE-----` certPEM = `-----BEGIN CERTIFICATE----- MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q 5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC 7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+ gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283 TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq 0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` csrPEM = `-----BEGIN CERTIFICATE REQUEST----- MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu 7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2 HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe 
V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H /qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0 hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x /Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs 0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+ g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8 zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4 2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU 58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` stepCertPEM = `-----BEGIN CERTIFICATE----- MIIChTCCAiugAwIBAgIRAJ3O5T28Rdj2lr/UPjf+GAUwCgYIKoZIzj0EAwIwJDEi MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0xOTAyMjAyMDE1 NDNaFw0xOTAyMjEyMDE1NDNaMHExCzAJBgNVBAYTAlVTMQswCQYDVQQIEwJDQTEW MBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEcMBoGA1UEChMTU21hbGxzdGVwIExhYnMg SW5jLjEfMB0GA1UEAxMWaW50ZXJuYWwuc21hbGxzdGVwLmNvbTBZMBMGByqGSM49 AgEGCCqGSM49AwEHA0IABC0aKrTNl+gXFuNkXisqX4/foLO3VMt+Kphngziim+fz aJhiS9JU+oFYLTNW6HWGUD8CNzfwrmWlVsAmiJwHKlKjgfAwge0wDgYDVR0PAQH/ BAQDAgWgMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4EFgQU JheKvlZqNv1IcgaC8WOS1Zg0i1QwHwYDVR0jBBgwFoAUu97PaFQPfuyKOeew7Hg4 5WFIAVMwIQYDVR0RBBowGIIWaW50ZXJuYWwuc21hbGxzdGVwLmNvbTBZBgwrBgEE AYKkZMYoQAEESTBHAgEBBBVtYXJpYW5vQHNtYWxsc3RlcC5jb20EK2pPMzdkdERi a3UtUW5hYnM1VlIwWXc2WUZGdjl3ZUExOGRwM2h0dmRFanMwCgYIKoZIzj0EAwID 
SAAwRQIhAIrn17fP5CBrGtKuhyPiq6eSwryBCf8ki+k17u5a+E/LAiB24Y2E0Put nIHOI54lAqDeF7A0y73fPRVCiJEWmuxz0g== -----END CERTIFICATE-----` pubKey = `{ "use": "sig", "kty": "EC", "kid": "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00", "crv": "P-256", "alg": "ES256", "x": "p9QX4tzjxUrB0fgqRWLKUuPolDtBW681f2Qyh-uVNhk", "y": "CNSEloc4oLDFTX0Vywj0WiqOlh516sFQwCj6WtM8LT8" }` privKey = "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiNEhBYjE0WDQ5OFM4LWxSb29JTnpqZyJ9.RbkJXGzI3kOsaP20KmZs0ELFLgpRddAE49AJHlEblw-uH_gg6SV3QA.M3MArEpHgI171lhm.gBlFySpzK9F7riBJbtLSNkb4nAw_gWokqs1jS-ZK1qxuqTK-9mtX5yILjRnftx9P9uFp5xt7rvv4Mgom1Ed4V9WtIyfNP_Cz3Pme1Eanp5nY68WCe_yG6iSB1RJdMDBUb2qBDZiBdhJim1DRXsOfgedOrNi7GGbppMlD77DEpId118owR5izA-c6Q_hg08hIE3tnMAnebDNQoF9jfEY99_AReVRH8G4hgwZEPCfXMTb3J-lowKGG4vXIbK5knFLh47SgOqG4M2M51SMS-XJ7oBz1Vjoamc90QIqKV51rvZ5m0N_sPFtxzcfV4E9yYH3XVd4O-CG4ydVKfKVyMtQ.mcKFZqBHp_n7Ytj2jz9rvw" ) func parseCertificate(data string) *x509.Certificate { block, _ := pem.Decode([]byte(data)) if block == nil { panic("failed to parse certificate PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { panic("failed to parse certificate: " + err.Error()) } return cert } func parseCertificateRequest(data string) *x509.CertificateRequest { block, _ := pem.Decode([]byte(data)) if block == nil { panic("failed to parse certificate request PEM") } csr, err := x509.ParseCertificateRequest(block.Bytes) if err != nil { panic("failed to parse certificate request: " + err.Error()) } return csr } func TestNewCertificate(t *testing.T) { cert := parseCertificate(rootPEM) if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) { t.Errorf("NewCertificate failed, got %v, wants %v", NewCertificate(cert), Certificate{Certificate: cert}) } } func TestCertificate_MarshalJSON(t *testing.T) { type fields struct { Certificate *x509.Certificate } tests := []struct { name string fields fields want []byte wantErr bool }{ 
{"nil", fields{Certificate: nil}, []byte("null"), false},
		// A certificate with no raw DER still yields a PEM block, just with an empty body.
		{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
		// Real certificates: the expected JSON string is the fixture PEM with its
		// newlines escaped, plus the trailing newline that follows the END line.
		{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
		{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := Certificate{
				Certificate: tt.fields.Certificate,
			}
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("Certificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// The comparison is on the raw JSON bytes, not on a decoded value.
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Certificate.MarshalJSON() = %s, want %s", got, tt.want)
			}
		})
	}
}

// TestCertificate_UnmarshalJSON calls Certificate.UnmarshalJSON directly with
// raw byte inputs. wantErr states whether an error is expected; wantCert states
// whether a certificate must have been decoded. A nil Certificate is reported
// only when wantCert is true; the opposite direction is not asserted here.
func TestCertificate_UnmarshalJSON(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		wantCert bool
		wantErr  bool
	}{
		{"no data", nil, false, true},
		{"incomplete string 1", []byte(`"foobar`), false, true},
		{"incomplete string 2", []byte(`foobar"`), false, true},
		{"invalid string", []byte(`"foobar"`), false, true},
		{"invalid bytes 0", []byte{}, false, true},
		{"invalid bytes 1", []byte{1}, false, true},
		// The END line below has only four trailing dashes.
		// NOTE(review): confirm whether the malformed END line is deliberate.
		{"empty crt", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
		// A certificate request PEM is fed to the certificate decoder and must be rejected.
		{"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
		// An empty JSON string and JSON null are accepted without error.
		{"empty string", []byte(`""`), false, false},
		{"json null", []byte(`null`), false, false},
		{"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
		{"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var c Certificate
			if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
				t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantCert && c.Certificate == nil {
				t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil")
			}
		})
	}
}

// TestCertificate_UnmarshalJSON_json decodes a Certificate through
// json.Unmarshal as the "crt" field of a wrapper struct. Unlike the direct
// test, it asserts both directions: the certificate must be non-nil when
// wantCert is true and nil when it is false.
func TestCertificate_UnmarshalJSON_json(t *testing.T) {
	tests := []struct {
		name     string
		data     string
		wantCert bool
		wantErr  bool
	}{
		{"invalid type (bool)", `{"crt":true}`, false, true},
		{"invalid type (number)", `{"crt":123}`, false, true},
		{"invalid type (object)", `{"crt":{}}`, false, true},
		{"empty crt (null)", `{"crt":null}`, false, false},
		{"empty crt (string)", `{"crt":""}`, false, false},
		// The END line below has only four trailing dashes.
		{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
		{"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
	}

	type request struct {
		Cert Certificate `json:"crt"`
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var body request
			if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
				t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
			}

			switch tt.wantCert {
			case true:
				if body.Cert.Certificate == nil {
					t.Error("json.Unmarshal() failed, Certificate is nil")
				}
			case false:
				if body.Cert.Certificate != nil {
					t.Error("json.Unmarshal() failed, Certificate is not nil")
				}
			}
		})
	}
}

// TestNewCertificateRequest checks that NewCertificateRequest returns a
// CertificateRequest wrapping exactly the *x509.CertificateRequest it was given.
func TestNewCertificateRequest(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	if !reflect.DeepEqual(CertificateRequest{CertificateRequest: csr}, NewCertificateRequest(csr)) {
		t.Errorf("NewCertificateRequest failed, got %v, wants %v", NewCertificateRequest(csr), CertificateRequest{CertificateRequest: csr})
	}
}

// TestCertificateRequest_MarshalJSON compares the raw bytes produced by
// CertificateRequest.MarshalJSON with the expected JSON for a nil request,
// a zero-value request and the fixture CSR.
func TestCertificateRequest_MarshalJSON(t *testing.T) {
	type fields struct {
		CertificateRequest *x509.CertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		want    []byte
		wantErr bool
	}{
		{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
		{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
		{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)},
[]byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := CertificateRequest{ CertificateRequest: tt.fields.CertificateRequest, } got, err := c.MarshalJSON() if (err != nil) != tt.wantErr { t.Errorf("CertificateRequest.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CertificateRequest.MarshalJSON() = %s, want %s", got, tt.want) } }) } } func TestCertificateRequest_UnmarshalJSON(t *testing.T) { tests := []struct { name string data []byte wantCert bool wantErr bool }{ {"no data", nil, false, true}, {"incomplete string 1", []byte(`"foobar`), false, true}, {"incomplete string 2", []byte(`foobar"`), false, true}, {"invalid string", []byte(`"foobar"`), false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true}, {"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true}, {"empty string", []byte(`""`), false, false}, {"json null", []byte(`null`), false, false}, {"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var c CertificateRequest if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr { t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } if tt.wantCert && c.CertificateRequest == nil { t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequet is nil") } }) } } func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) { tests := []struct { name string data string wantCert bool wantErr bool }{ {"invalid type (bool)", `{"csr":true}`, false, true}, {"invalid type (number)", `{"csr":123}`, false, true}, {"invalid type (object)", `{"csr":{}}`, false, true}, {"empty csr (null)", 
`{"csr":null}`, false, false},
		{"empty csr (string)", `{"csr":""}`, false, false},
		// The END line below has only four trailing dashes.
		{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
		{"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
	}

	type request struct {
		CSR CertificateRequest `json:"csr"`
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var body request
			if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
				t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
			}

			switch tt.wantCert {
			case true:
				if body.CSR.CertificateRequest == nil {
					t.Error("json.Unmarshal() failed, CertificateRequest is nil")
				}
			case false:
				if body.CSR.CertificateRequest != nil {
					t.Error("json.Unmarshal() failed, CertificateRequest is not nil")
				}
			}
		})
	}
}

// TestSignRequest_Validate builds SignRequest values from the table and checks
// SignRequest.Validate. Every case in the table expects an error; when Validate
// returns one, its message only has to start with the expected text.
func TestSignRequest_Validate(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	// bad is a second parse of the same CSR with the first signature byte
	// incremented; it backs the "invalid csr" case.
	bad := parseCertificateRequest(csrPEM)
	bad.Signature[0]++
	type fields struct {
		CsrPEM    CertificateRequest
		OTT       string
		NotBefore time.Time
		NotAfter  time.Time
	}
	tests := []struct {
		name   string
		fields fields
		err    error
	}{
		{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
		{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
		{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &SignRequest{
				CsrPEM:    tt.fields.CsrPEM,
				OTT:       tt.fields.OTT,
				NotAfter:  NewTimeDuration(tt.fields.NotAfter),
				NotBefore: NewTimeDuration(tt.fields.NotBefore),
			}
			if err := s.Validate(); err != nil {
				// An error is only acceptable if the case expects one, and then
				// it is matched by prefix rather than by full equality.
				if assert.NotNil(t, tt.err) {
					assert.HasPrefix(t, err.Error(), tt.err.Error())
				}
			} else {
				// No error returned: the case must not have expected one.
				assert.Nil(t, tt.err)
			}
		})
	}
}

// mockProvisioner is a test double for a provisioner. Each exported method
// calls the function field of the matching name when that field is set.
// Otherwise it falls back to the canned values: ret1, ret2 and ret3 supply the
// results and err the error.
//
// The fallbacks use unchecked type assertions, so a test that relies on them
// must store values of the method's result types in ret1/ret2/ret3 or the call
// panics. GetTokenID is the only method that tolerates a nil ret1.
type mockProvisioner struct {
	ret1, ret2, ret3   interface{}
	err                error
	getID              func() string
	getIDForToken      func() string
	getTokenID         func(string) (string, error)
	getName            func() string
	getType            func() provisioner.Type
	getEncryptedKey    func() (string, string, bool)
	init               func(provisioner.Config) error
	authorizeRenew     func(ctx context.Context, cert *x509.Certificate) error
	authorizeRevoke    func(ctx context.Context, token string) error
	authorizeSign      func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
	authorizeRenewal   func(*x509.Certificate) error
	authorizeSSHSign   func(ctx context.Context, token string) ([]provisioner.SignOption, error)
	authorizeSSHRevoke func(ctx context.Context, token string) error
	authorizeSSHRenew  func(ctx context.Context, token string) (*ssh.Certificate, error)
	authorizeSSHRekey  func(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error)
}

// GetID returns getID() when set, otherwise ret1 asserted as a string.
func (m *mockProvisioner) GetID() string {
	if m.getID != nil {
		return m.getID()
	}
	return m.ret1.(string)
}

// GetIDForToken returns getIDForToken() when set, otherwise ret1 asserted as a string.
func (m *mockProvisioner) GetIDForToken() string {
	if m.getIDForToken != nil {
		return m.getIDForToken()
	}
	return m.ret1.(string)
}

// GetTokenID returns getTokenID(token) when set. Otherwise it returns ret1 as
// a string together with err, or ("", err) when ret1 is nil.
func (m *mockProvisioner) GetTokenID(token string) (string, error) {
	if m.getTokenID != nil {
		return m.getTokenID(token)
	}
	if m.ret1 == nil {
		return "", m.err
	}
	return m.ret1.(string), m.err
}

// GetName returns getName() when set, otherwise ret1 asserted as a string.
func (m *mockProvisioner) GetName() string {
	if m.getName != nil {
		return m.getName()
	}
	return m.ret1.(string)
}

// GetType returns getType() when set, otherwise ret1 asserted as a provisioner.Type.
func (m *mockProvisioner) GetType() provisioner.Type {
	if m.getType != nil {
		return m.getType()
	}
	return m.ret1.(provisioner.Type)
}

// GetEncryptedKey returns getEncryptedKey() when set, otherwise ret1 and ret2
// asserted as strings and ret3 asserted as a bool.
func (m *mockProvisioner) GetEncryptedKey() (string, string, bool) {
	if m.getEncryptedKey != nil {
		return m.getEncryptedKey()
	}
	return m.ret1.(string), m.ret2.(string), m.ret3.(bool)
}

// Init returns init(c) when set, otherwise err.
func (m *mockProvisioner) Init(c provisioner.Config) error {
	if m.init != nil {
		return m.init(c)
	}
	return m.err
}

// AuthorizeRenew returns authorizeRenew(ctx, cert) when set, otherwise err.
func (m *mockProvisioner) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
	if m.authorizeRenew != nil {
		return m.authorizeRenew(ctx, cert)
	}
	return m.err
}

// AuthorizeRevoke returns authorizeRevoke(ctx, token) when set, otherwise err.
func (m *mockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error {
	if m.authorizeRevoke != nil {
		return m.authorizeRevoke(ctx, token)
	}
	return m.err
}

// AuthorizeSign returns authorizeSign(ctx, ott) when set, otherwise ret1
// asserted as []provisioner.SignOption together with err.
func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
	if m.authorizeSign != nil {
		return m.authorizeSign(ctx, ott)
	}
	return m.ret1.([]provisioner.SignOption), m.err
}

// AuthorizeRenewal returns authorizeRenewal(c) when set, otherwise err.
func (m *mockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
	if m.authorizeRenewal != nil {
		return m.authorizeRenewal(c)
	}
	return m.err
}

// AuthorizeSSHSign returns authorizeSSHSign(ctx, token) when set, otherwise
// ret1 asserted as []provisioner.SignOption together with err.
func (m *mockProvisioner) AuthorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
	if m.authorizeSSHSign != nil {
		return m.authorizeSSHSign(ctx, token)
	}
	return m.ret1.([]provisioner.SignOption), m.err
}

// AuthorizeSSHRevoke returns authorizeSSHRevoke(ctx, token) when set, otherwise err.
func (m *mockProvisioner) AuthorizeSSHRevoke(ctx context.Context, token string) error {
	if m.authorizeSSHRevoke != nil {
		return m.authorizeSSHRevoke(ctx, token)
	}
	return m.err
}

// AuthorizeSSHRenew returns authorizeSSHRenew(ctx, token) when set, otherwise
// ret1 asserted as *ssh.Certificate together with err.
func (m *mockProvisioner) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
	if m.authorizeSSHRenew != nil {
		return m.authorizeSSHRenew(ctx, token)
	}
	return m.ret1.(*ssh.Certificate), m.err
}

// AuthorizeSSHRekey returns authorizeSSHRekey(ctx, token) when set, otherwise
// ret1 asserted as *ssh.Certificate, ret2 asserted as []provisioner.SignOption,
// and err.
func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
	if m.authorizeSSHRekey != nil {
		return m.authorizeSSHRekey(ctx, token)
	}
	return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err
}

// mockAuthority is the test double the handler tests in this file pass to New
// and store in caHandler.Authority. Each function field, when set, overrides
// the method of the matching exported name. ret1, ret2 and err are the canned
// results used by methods whose override is unset.
type mockAuthority struct {
	ret1, ret2                   interface{}
	err                          error
	authorizeSign                func(ott string) ([]provisioner.SignOption, error)
	getTLSOptions                func() *authority.TLSOptions
	root                         func(shasum string) (*x509.Certificate, error)
	sign                         func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
	renew                        func(cert *x509.Certificate) ([]*x509.Certificate, error)
	rekey                        func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
	loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
	loadProvisionerByName        func(name string) (provisioner.Interface, error)
	getProvisioners              func(nextCursor string, limit int) (provisioner.List, string, error)
	revoke                       func(context.Context, *authority.RevokeOptions) error
	getEncryptedKey              func(kid string) (string, error)
	getRoots                     func() ([]*x509.Certificate, error)
	getFederation                func() ([]*x509.Certificate, error)
	signSSH                      func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
	signSSHAddUser               func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
	renewSSH                     func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
	rekeySSH                     func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
	getSSHHosts                  func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error)
	getSSHRoots                  func(ctx context.Context) (*authority.SSHKeys, error)
	getSSHFederation             func(ctx context.Context) (*authority.SSHKeys, error)
	getSSHConfig                 func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
	checkSSHHost                 func(ctx context.Context, principal, token string) (bool, error)
	getSSHBastion                func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
	version                      func() authority.Version
}

// TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
	// ctx is ignored; the call is forwarded to AuthorizeSign unchanged.
	return m.AuthorizeSign(ott)
}

// AuthorizeSign returns authorizeSign(ott) when that override is set;
// otherwise ret1 asserted as []provisioner.SignOption together with err.
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
	if m.authorizeSign != nil {
		return m.authorizeSign(ott)
	}
	return m.ret1.([]provisioner.SignOption), m.err
}

// GetTLSOptions returns getTLSOptions() when set, otherwise ret1 asserted as
// *authority.TLSOptions.
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
	if m.getTLSOptions != nil {
		return m.getTLSOptions()
	}
	return m.ret1.(*authority.TLSOptions)
}

// Root returns root(shasum) when set, otherwise ret1 asserted as
// *x509.Certificate together with err.
func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
	if m.root != nil {
		return m.root(shasum)
	}
	return m.ret1.(*x509.Certificate), m.err
}

// Sign returns sign(cr, opts, signOpts...) when set. Otherwise it returns a
// two-element chain built from ret1 and ret2 (both asserted as
// *x509.Certificate) together with err.
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
	if m.sign != nil {
		return m.sign(cr, opts, signOpts...)
	}
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

// Renew returns renew(cert) when set, otherwise the two-element chain
// {ret1, ret2} (both asserted as *x509.Certificate) together with err.
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
	if m.renew != nil {
		return m.renew(cert)
	}
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

// Rekey returns rekey(oldcert, pk) when set, otherwise the two-element chain
// {ret1, ret2} (both asserted as *x509.Certificate) together with err.
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
	if m.rekey != nil {
		return m.rekey(oldcert, pk)
	}
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

// GetProvisioners returns getProvisioners(nextCursor, limit) when set,
// otherwise ret1 asserted as provisioner.List, ret2 asserted as a string, and err.
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
	if m.getProvisioners != nil {
		return m.getProvisioners(nextCursor, limit)
	}
	return m.ret1.(provisioner.List), m.ret2.(string), m.err
}

// LoadProvisionerByCertificate returns loadProvisionerByCertificate(cert) when
// set, otherwise ret1 asserted as provisioner.Interface together with err.
func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) {
	if m.loadProvisionerByCertificate != nil {
		return m.loadProvisionerByCertificate(cert)
	}
	return m.ret1.(provisioner.Interface), m.err
}

// LoadProvisionerByName returns loadProvisionerByName(name) when set,
// otherwise ret1 asserted as provisioner.Interface together with err.
func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
	if m.loadProvisionerByName != nil {
		return m.loadProvisionerByName(name)
	}
	return m.ret1.(provisioner.Interface), m.err
}

// Revoke returns revoke(ctx, opts) when set, otherwise err.
func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error {
	if m.revoke != nil {
		return m.revoke(ctx, opts)
	}
	return m.err
}

// GetEncryptedKey returns getEncryptedKey(kid) when set, otherwise ret1
// asserted as a string together with err.
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
	if m.getEncryptedKey != nil {
		return m.getEncryptedKey(kid)
	}
	return m.ret1.(string), m.err
}

// GetRoots returns getRoots() when set, otherwise ret1 asserted as
// []*x509.Certificate together with err.
func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
	if m.getRoots != nil {
		return m.getRoots()
	}
	return m.ret1.([]*x509.Certificate), m.err
}

// GetFederation returns getFederation() when set, otherwise ret1 asserted as
// []*x509.Certificate together with err.
func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
	if m.getFederation != nil {
		return m.getFederation()
	}
	return m.ret1.([]*x509.Certificate), m.err
}

// SignSSH returns signSSH(ctx, key, opts, signOpts...) when set, otherwise
// ret1 asserted as *ssh.Certificate together with err.
func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
	if m.signSSH != nil {
		return m.signSSH(ctx, key, opts, signOpts...)
	}
	return m.ret1.(*ssh.Certificate), m.err
}

// SignSSHAddUser returns signSSHAddUser(ctx, key, cert) when set, otherwise
// ret1 asserted as *ssh.Certificate together with err.
func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
	if m.signSSHAddUser != nil {
		return m.signSSHAddUser(ctx, key, cert)
	}
	return m.ret1.(*ssh.Certificate), m.err
}

// RenewSSH returns renewSSH(ctx, cert) when set, otherwise ret1 asserted as
// *ssh.Certificate together with err.
func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
	if m.renewSSH != nil {
		return m.renewSSH(ctx, cert)
	}
	return m.ret1.(*ssh.Certificate), m.err
}

// RekeySSH returns rekeySSH(ctx, cert, key, signOpts...) when set, otherwise
// ret1 asserted as *ssh.Certificate together with err.
func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
	if m.rekeySSH != nil {
		return m.rekeySSH(ctx, cert, key, signOpts...)
	}
	return m.ret1.(*ssh.Certificate), m.err
}

// GetSSHHosts returns getSSHHosts(ctx, cert) when set, otherwise ret1 asserted
// as []authority.Host together with err.
func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) {
	if m.getSSHHosts != nil {
		return m.getSSHHosts(ctx, cert)
	}
	return m.ret1.([]authority.Host), m.err
}

// GetSSHRoots returns getSSHRoots(ctx) when set, otherwise ret1 asserted as
// *authority.SSHKeys together with err.
func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
	if m.getSSHRoots != nil {
		return m.getSSHRoots(ctx)
	}
	return m.ret1.(*authority.SSHKeys), m.err
}

// GetSSHFederation returns getSSHFederation(ctx) when set, otherwise ret1
// asserted as *authority.SSHKeys together with err.
func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
	if m.getSSHFederation != nil {
		return m.getSSHFederation(ctx)
	}
	return m.ret1.(*authority.SSHKeys), m.err
}

// GetSSHConfig returns getSSHConfig(ctx, typ, data) when set, otherwise ret1
// asserted as []templates.Output together with err.
func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
	if m.getSSHConfig != nil {
		return m.getSSHConfig(ctx, typ, data)
	}
	return m.ret1.([]templates.Output), m.err
}

// CheckSSHHost returns checkSSHHost(ctx, principal, token) when set, otherwise
// ret1 asserted as a bool together with err.
func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
	if m.checkSSHHost != nil {
		return m.checkSSHHost(ctx, principal, token)
	}
	return m.ret1.(bool), m.err
}

// GetSSHBastion returns getSSHBastion(ctx, user, hostname) when set, otherwise
// ret1 asserted as *authority.Bastion together with err.
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
	if m.getSSHBastion != nil {
		return m.getSSHBastion(ctx, user, hostname)
	}
	return m.ret1.(*authority.Bastion), m.err
}

// Version returns version() when set, otherwise ret1 asserted as authority.Version.
func (m *mockAuthority) Version() authority.Version {
	if m.version != nil {
		return m.version()
	}
	return m.ret1.(authority.Version)
}

// Test_caHandler_Route calls caHandler.Route with a fresh chi router and an
// empty mockAuthority. It makes no assertions, so it can only fail if Route
// panics.
func Test_caHandler_Route(t *testing.T) {
	type fields struct {
		Authority Authority
	}
	type args struct {
		r Router
	}
	tests := []struct {
		name   string
		fields fields
		args   args
	}{
		{"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &caHandler{
				Authority: tt.fields.Authority,
			}
			h.Route(tt.args.r)
		})
	}
}

// Test_caHandler_Health sends a GET /health request through caHandler.Health
// and checks the recorded status code and response body.
func Test_caHandler_Health(t *testing.T) {
	req := httptest.NewRequest("GET", "http://example.com/health", nil)
	w := httptest.NewRecorder()
	h := New(&mockAuthority{}).(*caHandler)
h.Health(w, req)

	res := w.Result()
	if res.StatusCode != 200 {
		t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)
	}

	body, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Errorf("caHandler.Health unexpected error = %v", err)
	}
	// The body is compared exactly, including the trailing newline.
	expected := []byte("{\"status\":\"ok\"}\n")
	if !bytes.Equal(body, expected) {
		t.Errorf("caHandler.Health Body = %s, wants %s", body, expected)
	}
}

// Test_caHandler_Root exercises caHandler.Root with a mock authority whose
// Root fallback returns the table's certificate and error. Both subtests reuse
// one request carrying the "sha" chi URL parameter. The response body is only
// checked for the 200 case.
func Test_caHandler_Root(t *testing.T) {
	tests := []struct {
		name       string
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", parseCertificate(rootPEM), nil, 200},
		{"fail", nil, fmt.Errorf("not found"), 404},
	}

	// Request with chi context
	chiCtx := chi.NewRouteContext()
	chiCtx.URLParams.Add("sha", "efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36")
	req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
	req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))

	expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
			w := httptest.NewRecorder()
			h.Root(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Root unexpected error = %v", err)
			}
			if tt.statusCode == 200 {
				// Surrounding whitespace in the response is ignored.
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_caHandler_Sign posts sign requests to caHandler.Sign. The mock
// authority's authorizeSign override returns the case's options and
// authorization error. Sign has no override, so it falls back to the
// {ret1, ret2} chain and err, which carry the case's cert, root and signErr.
// The response body is only compared for status codes below 400.
func Test_caHandler_Sign(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	valid, err := json.Marshal(SignRequest{
		CsrPEM: CertificateRequest{csr},
		OTT:    "foobarzar",
	})
	if err != nil {
		t.Fatal(err)
	}
	// invalid has an empty OTT; it backs the "validate error" case.
	invalid, err := json.Marshal(SignRequest{
		CsrPEM: CertificateRequest{csr},
		OTT:    "",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Expected bodies: the leaf in "crt", the root in "ca", and both again, in
	// that order, in "certChain".
	expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
	expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	tests := []struct {
		name         string
		input        string
		certAttrOpts []provisioner.SignOption
		autherr      error
		cert         *x509.Certificate
		root         *x509.Certificate
		signErr      error
		statusCode   int
		expected     []byte
	}{
		{"ok", string(valid), nil, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected1},
		{"ok with Provisioner", string(valid), nil, nil, parseCertificate(stepCertPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected2},
		{"json read error", "{", nil, nil, nil, nil, nil, http.StatusBadRequest, nil},
		{"validate error", string(invalid), nil, nil, nil, nil, nil, http.StatusBadRequest, nil},
		{"authorize error", string(valid), nil, fmt.Errorf("an error"), nil, nil, nil, http.StatusUnauthorized, nil},
		{"sign error", string(valid), nil, nil, nil, nil, fmt.Errorf("an error"), http.StatusForbidden, nil},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.signErr,
				authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
					return tt.certAttrOpts, tt.autherr
				},
				getTLSOptions: func() *authority.TLSOptions {
					return nil
				},
			}).(*caHandler)
			req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
			w := httptest.NewRecorder()
			h.Sign(logging.NewResponseLogger(w), req)
			res := w.Result()

			// Failure messages name the Sign handler so a failing subtest
			// points at the right endpoint.
			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Sign StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Sign unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
					t.Errorf("caHandler.Sign Body = %s, wants %s", body, tt.expected)
				}
			}
		})
	}
}

// Test_caHandler_Renew posts to caHandler.Renew with the table's TLS
// connection state on the request. The mock authority's Renew fallback
// returns the {ret1, ret2} chain and err from the case. The response body is
// only compared for status codes below 400.
func Test_caHandler_Renew(t *testing.T) {
	// cs is a connection state carrying one peer certificate.
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	}
	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
		{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
		{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
	}

	expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.err,
				getTLSOptions: func() *authority.TLSOptions {
					return nil
				},
			}).(*caHandler)
			req := httptest.NewRequest("POST", "http://example.com/renew", nil)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Renew(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Renew unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Renew Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_caHandler_Rekey posts rekey requests to caHandler.Rekey and checks the
// recorded status code and, for successful responses, the body.
func Test_caHandler_Rekey(t *testing.T) {
	cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } csr := parseCertificateRequest(csrPEM) valid, err := json.Marshal(RekeyRequest{ CsrPEM: CertificateRequest{csr}, }) if err != nil { t.Fatal(err) } tests := []struct { name string input string tls *tls.ConnectionState cert *x509.Certificate root *x509.Certificate err error statusCode int }{ {"ok", string(valid), cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"no tls", string(valid), nil, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", string(valid), &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, {"rekey error", string(valid), cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, {"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest}, } expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil }, }).(*caHandler) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req.TLS = tt.tls w := httptest.NewRecorder() h.Rekey(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Rekey unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Rekey Body = %s, wants %s", body, expected) } } }) } } func Test_caHandler_Provisioners(t *testing.T) { type fields struct { Authority Authority } type args struct { w 
http.ResponseWriter r *http.Request } req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", http.NoBody) if err != nil { t.Fatal(err) } reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", http.NoBody) if err != nil { t.Fatal(err) } var key jose.JSONWebKey if err := json.Unmarshal([]byte(pubKey), &key); err != nil { t.Fatal(err) } p := provisioner.List{ &provisioner.JWK{ Type: "JWK", Name: "max", EncryptedKey: "abc", Key: &key, }, &provisioner.JWK{ Type: "JWK", Name: "mariano", EncryptedKey: "def", Key: &key, }, } pr := ProvisionersResponse{ Provisioners: p, } tests := []struct { name string fields fields args args statusCode int }{ {"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, {"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, {"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, } expected, err := json.Marshal(pr) if err != nil { t.Fatal(err) } expectedError400 := errs.BadRequest("limit 'abc' is not an integer") expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) expectedError500 := errs.InternalServer("force") expectedError500Bytes, err := json.Marshal(expectedError500) assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &caHandler{ Authority: tt.fields.Authority, } h.Provisioners(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Provisioners unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Provisioners 
Body = %s, wants %s", body, expected) } } else { switch tt.statusCode { case 400: if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) { t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes) } case 500: if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) { t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes) } default: t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode) } } }) } } func Test_caHandler_ProvisionerKey(t *testing.T) { type fields struct { Authority Authority } type args struct { w http.ResponseWriter r *http.Request } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("kid", "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00") req := httptest.NewRequest("GET", "http://example.com/provisioners/oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00/encrypted-key", nil) req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) tests := []struct { name string fields fields args args statusCode int }{ {"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, {"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, } expected := []byte(`{"key":"` + privKey + `"}`) expectedError404 := errs.NotFound("force") expectedError404Bytes, err := json.Marshal(expectedError404) assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &caHandler{ Authority: tt.fields.Authority, } h.ProvisionerKey(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Provisioners unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if 
!bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected) } } else { if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) { t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes) } } }) } } func Test_caHandler_Roots(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } tests := []struct { name string tls *tls.ConnectionState cert *x509.Certificate root *x509.Certificate err error statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() h.Roots(w, req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Roots unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected) } } }) } } func Test_caHandler_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } tests := []struct { name string tls *tls.ConnectionState cert *x509.Certificate root *x509.Certificate err error statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, 
http.StatusCreated}, {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() h.Federation(w, req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Federation unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected) } } }) } } func Test_fmtPublicKey(t *testing.T) { p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { t.Fatal(err) } edPub, edPriv, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } var dsa2048 dsa.PrivateKey if err := dsa.GenerateParameters(&dsa2048.Parameters, rand.Reader, dsa.L2048N256); err != nil { t.Fatal(err) } if err := dsa.GenerateKey(&dsa2048, rand.Reader); err != nil { t.Fatal(err) } type args struct { pub, priv interface{} cert *x509.Certificate } tests := []struct { name string args args want string }{ {"p256", args{p256.Public(), p256, nil}, "ECDSA P-256"}, {"rsa1024", args{rsa1024.Public(), rsa1024, nil}, "RSA 1024"}, {"ed25519", args{edPub, edPriv, nil}, "Ed25519"}, {"dsa2048", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.DSA, PublicKey: &dsa2048.PublicKey}}, "DSA 2048"}, {"unknown", args{cert: 
&x509.Certificate{PublicKeyAlgorithm: x509.ECDSA, PublicKey: []byte("12345678")}}, "ECDSA unknown"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var cert *x509.Certificate if tt.args.cert != nil { cert = tt.args.cert } else { cert = mustCertificate(t, tt.args.pub, tt.args.priv) } if got := fmtPublicKey(cert); got != tt.want { t.Errorf("fmtPublicKey() = %v, want %v", got, tt.want) } }) } } func mustCertificate(t *testing.T, pub, priv interface{}) *x509.Certificate { template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ Organization: []string{"Acme Co"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } der, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv) if err != nil { t.Fatal(err) } cert, err := x509.ParseCertificate(der) if err != nil { t.Fatal(err) } return cert }