From baf3c40fef7f102e40d224891f9744f2b65d3b77 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Mon, 21 Mar 2022 16:55:09 -0700 Subject: [PATCH 01/20] Print some basic configuration info on startup --- authority/authority.go | 10 ++++++++-- ca/ca.go | 3 +++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index cc26635e..516c8130 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -294,8 +294,6 @@ func (a *Authority) init() error { return err } a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate) - sum := sha256.Sum256(resp.RootCertificate.Raw) - log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:])) } } @@ -313,6 +311,7 @@ func (a *Authority) init() error { for _, crt := range a.rootX509Certs { sum := sha256.Sum256(crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), crt) + log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) } a.rootX509CertPool = x509.NewCertPool() @@ -541,6 +540,13 @@ func (a *Authority) init() error { a.templates.Data["Step"] = tmplVars } + if tmplVars.SSH.HostKey != nil { + log.Printf("SSH Host CA Key: %s\n", ssh.MarshalAuthorizedKey(tmplVars.SSH.HostKey)) + } + if tmplVars.SSH.HostKey != nil { + log.Printf("SSH User CA Key: %s\n", ssh.MarshalAuthorizedKey(tmplVars.SSH.UserKey)) + } + // JWT numeric dates are seconds. a.startTime = time.Now().Truncate(time.Second) // Set flag indicating that initialization has been completed, and should diff --git a/ca/ca.go b/ca/ca.go index c95ba22f..3be03e34 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -288,6 +288,9 @@ func (ca *CA) Run() error { var wg sync.WaitGroup errs := make(chan error, 1) + log.Printf("Documentation: https://u.step.sm/docs/ca") + log.Printf("Config File: %s", ca.opts.configFile) + if ca.insecureSrv != nil { wg.Add(1) go func() { From 91a25b52bdc3a3ee06c0fdae717641766b73aa94 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Mon, 21 Mar 2022 16:59:28 -0700 Subject: [PATCH 02/20] Print discord --- ca/ca.go | 1 + 1 file changed, 1 insertion(+) diff --git a/ca/ca.go b/ca/ca.go index 3be03e34..2751d050 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -289,6 +289,7 @@ func (ca *CA) Run() error { errs := make(chan error, 1) log.Printf("Documentation: https://u.step.sm/docs/ca") + log.Printf("Community Discord: https://u.step.sm/discord") log.Printf("Config File: %s", ca.opts.configFile) if ca.insecureSrv != nil { From 91be50cf70ef7ca1a2a8f92bf5f6f4e655bdafe1 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Mon, 21 Mar 2022 19:55:21 -0700 Subject: [PATCH 03/20] Add --quiet flag --- ca/ca.go | 18 +++++++++++++++--- commands/app.go | 8 +++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 2751d050..9c852a19 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -31,6 +31,7 @@ import ( type options struct { configFile string linkedCAToken string + quiet bool password []byte issuerPassword []byte sshHostPassword []byte @@ -101,6 +102,14 @@ func WithLinkedCAToken(token string) Option { } } +// WithQuiet sets the quiet flag. +func WithQuiet(quiet bool) Option { + return func(o *options) { + o.quiet = quiet + } +} + + // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { @@ -288,9 +297,11 @@ func (ca *CA) Run() error { var wg sync.WaitGroup errs := make(chan error, 1) - log.Printf("Documentation: https://u.step.sm/docs/ca") - log.Printf("Community Discord: https://u.step.sm/discord") - log.Printf("Config File: %s", ca.opts.configFile) + if !ca.opts.quiet { + log.Printf("Documentation: https://u.step.sm/docs/ca") + log.Printf("Community Discord: https://u.step.sm/discord") + log.Printf("Config File: %s", ca.opts.configFile) + } if ca.insecureSrv != nil { wg.Add(1) @@ -359,6 +370,7 @@ func (ca *CA) Reload() error { WithSSHUserPassword(ca.opts.sshUserPassword), WithIssuerPassword(ca.opts.issuerPassword), WithLinkedCAToken(ca.opts.linkedCAToken), + WithQuiet(ca.opts.quiet), WithConfigFile(ca.opts.configFile), WithDatabase(ca.auth.GetDatabase()), ) diff --git a/commands/app.go b/commands/app.go index 8c40de0e..47fb1444 100644 --- a/commands/app.go +++ b/commands/app.go @@ -57,6 +57,10 @@ certificate issuer private key used in the RA mode.`, Usage: "token used to enable the linked ca.", EnvVar: "STEP_CA_TOKEN", }, + cli.BoolFlag{ + Name: "quiet", + Usage: "disable startup information", + }, }, } @@ -68,6 +72,7 @@ func appAction(ctx *cli.Context) error { issuerPassFile := ctx.String("issuer-password-file") resolver := ctx.String("resolver") token := ctx.String("token") + quiet := ctx.Bool("quiet") // If zero cmd line args show help, if >1 cmd line args show error. if ctx.NArg() == 0 { @@ -141,7 +146,8 @@ To get a linked authority token: ca.WithSSHHostPassword(sshHostPassword), ca.WithSSHUserPassword(sshUserPassword), ca.WithIssuerPassword(issuerPassword), - ca.WithLinkedCAToken(token)) + ca.WithLinkedCAToken(token), + ca.WithQuiet(quiet)) if err != nil { fatal(err) } From 25cc9a172835207680d2fee3c13690a5688169f4 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Tue, 22 Mar 2022 07:38:09 -0700 Subject: [PATCH 04/20] Update authority/authority.go Co-authored-by: Herman Slatman --- authority/authority.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authority/authority.go b/authority/authority.go index 516c8130..50025cce 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -543,7 +543,7 @@ func (a *Authority) init() error { if tmplVars.SSH.HostKey != nil { log.Printf("SSH Host CA Key: %s\n", ssh.MarshalAuthorizedKey(tmplVars.SSH.HostKey)) } - if tmplVars.SSH.HostKey != nil { + if tmplVars.SSH.UserKey != nil { log.Printf("SSH User CA Key: %s\n", ssh.MarshalAuthorizedKey(tmplVars.SSH.UserKey)) } From f20784be56943dbb430f588fec761af61530ccd3 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Tue, 22 Mar 2022 10:41:16 -0700 Subject: [PATCH 05/20] format --- ca/ca.go | 1 - commands/app.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 9c852a19..41f48483 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -109,7 +109,6 @@ func WithQuiet(quiet bool) Option { } } - // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { diff --git a/commands/app.go b/commands/app.go index 47fb1444..984ce067 100644 --- a/commands/app.go +++ b/commands/app.go @@ -58,7 +58,7 @@ certificate issuer private key used in the RA mode.`, EnvVar: "STEP_CA_TOKEN", }, cli.BoolFlag{ - Name: "quiet", + Name: "quiet", Usage: "disable startup information", }, }, From 955d4cf80d1cdcd697b08694c53d57d6effd2ee9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 28 Mar 2022 17:54:35 -0700 Subject: [PATCH 06/20] Add authority.WithX509SignerFunc This change adds a new authority option that allows to pass a callback that returns the certificate chain and signer used to sign X.509 certificates. This option will be used by Caddy, they renew the intermediate certificate weekly and there's no other way to replace it without re-creating the embedded CA. Fixes #874 --- authority/options.go | 16 ++++ cas/apiv1/options.go | 9 ++- cas/softcas/softcas.go | 58 +++++++++---- cas/softcas/softcas_test.go | 157 +++++++++++++++++++++++++++++------- 4 files changed, 194 insertions(+), 46 deletions(-) diff --git a/authority/options.go b/authority/options.go index a1238b1d..1c154577 100644 --- a/authority/options.go +++ b/authority/options.go @@ -163,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option { } } +// WithX509SignerFunc defines the function used to get the chain of certificates +// and signer used when we sign X.509 certificates. +func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option { + return func(a *Authority) error { + srv, err := cas.New(context.Background(), casapi.Options{ + Type: casapi.SoftCAS, + CertificateSigner: fn, + }) + if err != nil { + return err + } + a.x509CAService = srv + return nil + } +} + // WithSSHUserSigner defines the signer used to sign SSH user certificates. func WithSSHUserSigner(s crypto.Signer) Option { return func(a *Authority) error { diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index badad7fc..408c5f96 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -31,13 +31,18 @@ type Options struct { // https://cloud.google.com/docs/authentication. CredentialsFile string `json:"credentialsFile,omitempty"` - // Certificate and signer are the issuer certificate, along with any other - // bundled certificates to be returned in the chain for consumers, and + // CertificateChain and Signer are the issuer certificate, along with any + // other bundled certificates to be returned in the chain for consumers, and // signer used in SoftCAS. They are configured in ca.json crt and key // properties. CertificateChain []*x509.Certificate `json:"-"` Signer crypto.Signer `json:"-"` + // CertificateSigner combines CertificateChain and Signer in a callback that + // returns the chain of certificate and signer used to sign X.509 + // certificates in SoftCAS. + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) `json:"-"` + // IsCreator is set to true when we're creating a certificate authority. It // is used to skip some validations when initializing a // CertificateAuthority. This option is used on SoftCAS and CloudCAS. diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 8e67d016..2a97145b 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -24,9 +24,10 @@ var now = time.Now // SoftCAS implements a Certificate Authority Service using Golang or KMS // crypto. This is the default CAS used in step-ca. type SoftCAS struct { - CertificateChain []*x509.Certificate - Signer crypto.Signer - KeyManager kms.KeyManager + CertificateChain []*x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) + KeyManager kms.KeyManager } // New creates a new CertificateAuthorityService implementation using Golang or KMS @@ -34,16 +35,17 @@ type SoftCAS struct { func New(ctx context.Context, opts apiv1.Options) (*SoftCAS, error) { if !opts.IsCreator { switch { - case len(opts.CertificateChain) == 0: + case len(opts.CertificateChain) == 0 && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'CertificateChain' cannot be nil") - case opts.Signer == nil: + case opts.Signer == nil && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'signer' cannot be nil") } } return &SoftCAS{ - CertificateChain: opts.CertificateChain, - Signer: opts.Signer, - KeyManager: opts.KeyManager, + CertificateChain: opts.CertificateChain, + Signer: opts.Signer, + CertificateSigner: opts.CertificateSigner, + KeyManager: opts.KeyManager, }, nil } @@ -57,6 +59,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 } t := now() + // Provisioners can also set specific values. if req.Template.NotBefore.IsZero() { req.Template.NotBefore = t.Add(-1 * req.Backdate) @@ -64,16 +67,21 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 if req.Template.NotAfter.IsZero() { req.Template.NotAfter = t.Add(req.Lifetime) } - req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + chain, signer, err := c.getCertSigner() + if err != nil { + return nil, err + } + req.Template.Issuer = chain[0].Subject + + cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -89,16 +97,21 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R t := now() req.Template.NotBefore = t.Add(-1 * req.Backdate) req.Template.NotAfter = t.Add(req.Lifetime) - req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + chain, signer, err := c.getCertSigner() + if err != nil { + return nil, err + } + req.Template.Issuer = chain[0].Subject + + cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.RenewCertificateResponse{ Certificate: cert, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -106,9 +119,13 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R // operation is a no-op as the actual revoke will happen when we store the entry // in the db. func (c *SoftCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { + chain, _, err := c.getCertSigner() + if err != nil { + return nil, err + } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -179,7 +196,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori }, nil } -// initializeKeyManager initiazes the default key manager if was not given. +// initializeKeyManager initializes the default key manager if was not given. func (c *SoftCAS) initializeKeyManager() (err error) { if c.KeyManager == nil { c.KeyManager, err = kms.New(context.Background(), kmsapi.Options{ @@ -189,6 +206,15 @@ func (c *SoftCAS) initializeKeyManager() (err error) { return } +// getCertSigner returns the certificate chain and signer to use. +func (c *SoftCAS) getCertSigner() ([]*x509.Certificate, crypto.Signer, error) { + if c.CertificateSigner != nil { + return c.CertificateSigner() + } + return c.CertificateChain, c.Signer, nil + +} + // createKey uses the configured kms to create a key. func (c *SoftCAS) createKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { if err := c.initializeKeyManager(); err != nil { diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index 7d3add4f..b4f5b440 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -73,6 +73,12 @@ var ( testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour)) testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) + testCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { + return []*x509.Certificate{testIssuer}, testSigner, nil + } + testFailCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { + return nil, nil, errTest + } ) type signatureAlgorithmSigner struct { @@ -186,6 +192,10 @@ func setTeeReader(t *testing.T, w *bytes.Buffer) { } func TestNew(t *testing.T) { + assertEqual := func(x, y interface{}) bool { + return reflect.DeepEqual(x, y) || fmt.Sprintf("%#v", x) == fmt.Sprintf("%#v", y) + } + type args struct { ctx context.Context opts apiv1.Options @@ -197,6 +207,7 @@ func TestNew(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false}, + {"ok with callback", args{context.Background(), apiv1.Options{CertificateSigner: testCertificateSigner}}, &SoftCAS{CertificateSigner: testCertificateSigner}, false}, {"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true}, {"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true}, } @@ -207,7 +218,7 @@ func TestNew(t *testing.T) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { + if !assertEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) @@ -265,8 +276,9 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { } type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.CreateCertificateRequest @@ -278,43 +290,53 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { want *apiv1.CreateCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &saTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with notBefore", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNotBefore, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok with notBefore+notAfter", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with notBefore+notAfter", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplWithLifetime, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"fail template", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, - {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, - {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.CreateCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, + {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, + {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.CreateCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -345,8 +367,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { } type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RenewCertificateRequest @@ -358,30 +381,40 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { want *apiv1.RenewCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, - {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, - {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ + {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, + {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, + {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -397,8 +430,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { func TestSoftCAS_RevokeCertificate(t *testing.T) { type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RevokeCertificateRequest @@ -410,7 +444,7 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) { want *apiv1.RevokeCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Reason: "test reason", ReasonCode: 1, @@ -418,23 +452,37 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) { Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok no cert", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ + {"ok no cert", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Reason: "test reason", ReasonCode: 1, }}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok empty", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ + {"ok empty", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + Reason: "test reason", + ReasonCode: 1, + }}, &apiv1.RevokeCertificateResponse{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + Reason: "test reason", + ReasonCode: 1, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -609,3 +657,56 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { }) } } + +func TestSoftCAS_defaultKeyManager(t *testing.T) { + mockNow(t) + type args struct { + req *apiv1.CreateCertificateAuthorityRequest + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok root", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + SerialNumber: big.NewInt(1234), + }, + Lifetime: 24 * time.Hour, + }}, false}, + {"ok intermediate", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: testSigner, + }, + }}, false}, + {"fail with default key manager", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: &badSigner{}, + }, + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &SoftCAS{} + _, err := c.CreateCertificateAuthority(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} From c480936ba4c4f2a988768558d88658c88ebfaff5 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 29 Mar 2022 12:02:17 -0700 Subject: [PATCH 07/20] Split comments. --- cas/apiv1/options.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index 408c5f96..50c3a2be 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -31,12 +31,15 @@ type Options struct { // https://cloud.google.com/docs/authentication. CredentialsFile string `json:"credentialsFile,omitempty"` - // CertificateChain and Signer are the issuer certificate, along with any - // other bundled certificates to be returned in the chain for consumers, and - // signer used in SoftCAS. They are configured in ca.json crt and key - // properties. + // CertificateChain contains the issuer certificate, along with any other + // bundled certificates to be returned in the chain for consumers. It is + // used used in SoftCAS, and is configured in the crt property of the + // ca.json. CertificateChain []*x509.Certificate `json:"-"` - Signer crypto.Signer `json:"-"` + + // Signer is the private key or a KMS signer for the issuer certificate. It is used in + // SoftCAS and it is configured in the key property of the ca.json. + Signer crypto.Signer `json:"-"` // CertificateSigner combines CertificateChain and Signer in a callback that // returns the chain of certificate and signer used to sign X.509 From abf5fc32a3417ca95800bf1d4c8cef2030cf8782 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 29 Mar 2022 14:26:17 -0700 Subject: [PATCH 08/20] Format comment. --- cas/apiv1/options.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index 50c3a2be..3fc34208 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -32,13 +32,13 @@ type Options struct { CredentialsFile string `json:"credentialsFile,omitempty"` // CertificateChain contains the issuer certificate, along with any other - // bundled certificates to be returned in the chain for consumers. It is - // used used in SoftCAS, and is configured in the crt property of the - // ca.json. + // bundled certificates to be returned in the chain to consumers. It is used + // used in SoftCAS and it is configured in the crt property of the ca.json. CertificateChain []*x509.Certificate `json:"-"` - // Signer is the private key or a KMS signer for the issuer certificate. It is used in - // SoftCAS and it is configured in the key property of the ca.json. + // Signer is the private key or a KMS signer for the issuer certificate. It + // is used in SoftCAS and it is configured in the key property of the + // ca.json. Signer crypto.Signer `json:"-"` // CertificateSigner combines CertificateChain and Signer in a callback that From 00634fb648e561e9eba6eb12ecd1ebf4001daa79 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Wed, 30 Mar 2022 11:22:22 +0300 Subject: [PATCH 09/20] api/render, api/log: initial implementation of the packages (#860) * api/render: initial implementation of the package * acme/api: refactored to support api/render * authority/admin: refactored to support api/render * ca: refactored to support api/render * api: refactored to support api/render * api/render: implemented Error * api: refactored to support api/render.Error * acme/api: refactored to support api/render.Error * authority/admin: refactored to support api/render.Error * ca: refactored to support api/render.Error * ca: fixed broken tests * api/render, api/log: moved error logging to this package * acme: refactored Error so that it implements render.RenderableError * authority/admin: refactored Error so that it implements render.RenderableError * api/render: implemented RenderableError * api/render: added test coverage for Error * api/render: implemented statusCodeFromError * api: refactored RootsPEM to work with render.Error * acme, authority/admin: fixed pointer receiver name for consistency * api/render, errs: moved StatusCoder & StackTracer to the render package --- acme/api/account.go | 47 +++++---- acme/api/handler.go | 38 +++---- acme/api/middleware.go | 87 +++++++-------- acme/api/order.go | 56 +++++----- acme/api/revoke.go | 39 +++---- acme/errors.go | 29 +---- api/api.go | 29 ++--- api/errors.go | 62 ----------- api/log/log.go | 40 ++++++- api/rekey.go | 11 +- api/render/render.go | 122 ++++++++++++++++++++++ api/render/render_test.go | 115 ++++++++++++++++++++ api/renew.go | 7 +- api/revoke.go | 15 +-- api/sign.go | 11 +- api/ssh.go | 63 +++++------ api/sshRekey.go | 18 ++-- api/sshRenew.go | 16 +-- api/sshRevoke.go | 11 +- api/utils.go | 57 ---------- api/utils_test.go | 53 ---------- authority/admin/api/acme.go | 16 +-- authority/admin/api/admin.go | 33 +++--- authority/admin/api/middleware.go | 8 +- authority/admin/api/provisioner.go | 54 +++++----- authority/admin/errors.go | 31 +----- authority/authorize_test.go | 49 +++++---- authority/provisioner/acme_test.go | 17 +-- authority/provisioner/aws_test.go | 25 ++--- authority/provisioner/azure_test.go | 25 ++--- authority/provisioner/gcp_test.go | 23 ++-- authority/provisioner/jwk_test.go | 34 +++--- authority/provisioner/k8sSA_test.go | 36 ++++--- authority/provisioner/oidc_test.go | 33 +++--- authority/provisioner/provisioner_test.go | 11 +- authority/provisioner/sshpop_test.go | 29 ++--- authority/provisioner/x5c_test.go | 36 ++++--- authority/provisioners_test.go | 12 +-- authority/root_test.go | 9 +- authority/ssh_test.go | 23 ++-- authority/tls_test.go | 36 ++++--- ca/acmeClient_test.go | 51 ++++----- ca/bootstrap_test.go | 11 +- ca/client_test.go | 74 ++++++------- errs/error.go | 19 ++-- 45 files changed, 859 insertions(+), 762 deletions(-) delete mode 100644 api/errors.go create mode 100644 api/render/render.go create mode 100644 api/render/render_test.go delete mode 100644 api/utils.go delete mode 100644 api/utils_test.go diff --git a/acme/api/account.go b/acme/api/account.go index 0dc8ab40..ade51aef 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -5,8 +5,9 @@ import ( "net/http" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/logging" ) @@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { // Something went wrong ... - api.WriteError(w, err) + render.Error(w, err) return } // Account does not exist // if nar.OnlyReturnExisting { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist")) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } eak, err := h.validateExternalAccountBinding(ctx, &nar) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Status: acme.StatusValid, } if err := h.db.CreateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) + render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response err := eak.BindTo(acc) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) + render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } acc.ExternalAccountBinding = nar.ExternalAccountBinding @@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) - api.JSONStatus(w, acc, httpStatus) + render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. @@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(uar.Status) > 0 || len(uar.Contact) > 0 { @@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } if err := h.db.UpdateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } @@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) - api.JSON(w, acc) + render.JSON(w, acc) } func logOrdersByAccount(w http.ResponseWriter, oids []string) { @@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } h.linker.LinkOrdersByAccountID(ctx, orders) - api.JSON(w, orders) + render.JSON(w, orders) logOrdersByAccount(w, orders) } diff --git a/acme/api/handler.go b/acme/api/handler.go index c3a481f9..10eb22cb 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -12,8 +12,10 @@ import ( "time" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" ) @@ -181,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.JSON(w, &Directory{ + render.JSON(w, &Directory{ NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), @@ -200,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. @@ -208,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return } if acc.ID != az.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } if err = az.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } h.linker.LinkAuthorization(ctx, az) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) - api.JSON(w, az) + render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. @@ -237,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. _, err = payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -257,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { azID := chi.URLParam(r, "authzID") ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return } ch.AuthorizationID = azID if acc.ID != ch.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } @@ -280,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) - api.JSON(w, ch) + render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. @@ -288,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } certID := chi.URLParam(r, "certID") cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return } if cert.AccountID != acc.ID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own certificate '%s'", acc.ID, certID)) return } diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 0cdeaabb..10f7841f 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -10,13 +10,14 @@ import ( "strings" "github.com/go-chi/chi" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { nonce, err := h.db.CreateNonce(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } w.Header().Set("Replay-Nonce", string(nonce)) @@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) + render.Error(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) @@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } @@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) + render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) + render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } @@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // For NewAccount and Revoke requests ... break case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -290,17 +291,17 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := h.ca.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) @@ -315,11 +316,11 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { ctx := r.Context() ok, err := h.prerequisitesChecker(ctx) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return } if !ok { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return } next(w, r.WithContext(ctx)) @@ -334,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return @@ -351,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -376,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -412,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ @@ -443,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !payload.isPostAsGet { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) diff --git a/acme/api/order.go b/acme/api/order.go index 9cf2c1eb..99eb0e95 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -11,9 +11,11 @@ import ( "time" "github.com/go-chi/chi" - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api/render" ) // NewOrderRequest represents the body for a NewOrder request. @@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { Status: acme.StatusPending, } if err := h.newAuthorization(ctx, az); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o.AuthorizationIDs[i] = az.ID @@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } if err := h.db.CreateOrder(ctx, o); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) + render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSONStatus(w, o, http.StatusCreated) + render.JSONStatus(w, o, http.StatusCreated) } func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { @@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. @@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) + render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // challengeTypes determines the types of challenges that should be used diff --git a/acme/api/revoke.go b/acme/api/revoke.go index d01e401c..4b71bc22 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -10,13 +10,14 @@ import ( "net/http" "strings" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ocsp" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" - "go.step.sm/crypto/jose" - "golang.org/x/crypto/ocsp" ) type revokePayload struct { @@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var p revokePayload err = json.Unmarshal(payload.value, &p) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) + render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload")) return } certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) if err != nil { // in this case the most likely cause is a client that didn't properly encode the certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) return } certToBeRevoked, err := x509.ParseCertificate(certBytes) if err != nil { // in this case a client may have encoded something different than a certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) return } serial := certToBeRevoked.SerialNumber.String() dbCert, err := h.db.GetCertificateBySerial(ctx, serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { // this should never happen - api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) + render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal")) return } if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } } else { @@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { _, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) + render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return } if hasBeenRevokedBefore { - api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) + render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) return } reasonCode := p.ReasonCode acmeErr := validateReasonCode(reasonCode) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } @@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) err = prov.AuthorizeRevoke(ctx, "") if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) + render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) return } options := revokeOptions(serial, certToBeRevoked, reasonCode) err = h.ca.Revoke(ctx, options) if err != nil { - api.WriteError(w, wrapRevokeErr(err)) + render.Error(w, wrapRevokeErr(err)) return } diff --git a/acme/errors.go b/acme/errors.go index a5c820ba..05888c24 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -3,13 +3,10 @@ package acme import ( "encoding/json" "fmt" - "log" "net/http" - "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the ACME problem. @@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) { return string(b), nil } -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err *Error) { +// Render implements render.RenderableError for Error. +func (e *Error) Render(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/problem+json") - w.WriteHeader(err.StatusCode()) - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err.Err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.Err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Println(err) - } + render.JSONStatus(w, e, e.StatusCode()) } diff --git a/api/api.go b/api/api.go index 1c47c03d..da6309fd 100644 --- a/api/api.go +++ b/api/api.go @@ -22,6 +22,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -284,7 +285,7 @@ func (h *caHandler) Route(r Router) { // Version is an HTTP handler that returns the version of the server. func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { v := h.Authority.Version() - JSON(w, VersionResponse{ + render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, }) @@ -292,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { // Health is an HTTP handler that returns the status of the server. func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { - JSON(w, HealthResponse{Status: "ok"}) + render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root @@ -303,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { - WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) + render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return } - JSON(w, &RootResponse{RootPEM: Certificate{cert}}) + render.JSON(w, &RootResponse{RootPEM: Certificate{cert}}) } func certChainToPEM(certChain []*x509.Certificate) []Certificate { @@ -322,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } p, next, err := h.Authority.GetProvisioners(cursor, limit) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &ProvisionersResponse{ + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -342,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := h.Authority.GetEncryptedKey(kid) if err != nil { - WriteError(w, errs.NotFoundErr(err)) + render.Error(w, errs.NotFoundErr(err)) return } - JSON(w, &ProvisionerKeyResponse{key}) + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return } @@ -361,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{roots[i]} } - JSONStatus(w, &RootsResponse{ + render.JSONStatus(w, &RootsResponse{ Certificates: certs, }, http.StatusCreated) } @@ -370,7 +371,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } @@ -393,7 +394,7 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return } @@ -402,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{federated[i]} } - JSONStatus(w, &FederationResponse{ + render.JSONStatus(w, &FederationResponse{ Certificates: certs, }, http.StatusCreated) } diff --git a/api/errors.go b/api/errors.go deleted file mode 100644 index 680e6578..00000000 --- a/api/errors.go +++ /dev/null @@ -1,62 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - - "github.com/pkg/errors" - - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api/log" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" -) - -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err error) { - switch k := err.(type) { - case *acme.Error: - acme.WriteError(w, k) - return - case *admin.Error: - admin.WriteError(w, k) - return - } - - cause := errors.Cause(err) - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } else if e, ok := cause.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - code := http.StatusInternalServerError - if sc, ok := err.(errs.StatusCoder); ok { - code = sc.StatusCode() - } else if sc, ok := cause.(errs.StatusCoder); ok { - code = sc.StatusCode() - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Error(w, err) - } -} diff --git a/api/log/log.go b/api/log/log.go index 78dae506..cb31410b 100644 --- a/api/log/log.go +++ b/api/log/log.go @@ -2,22 +2,54 @@ package log import ( + "fmt" "log" "net/http" + "os" + + "github.com/pkg/errors" "github.com/smallstep/certificates/logging" ) +// StackTracedError is the set of errors implementing the StackTrace function. +// +// Errors implementing this interface have their stack traces logged when passed +// to the Error function of this package. +type StackTracedError interface { + error + + StackTrace() errors.StackTrace +} + // Error adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. func Error(rw http.ResponseWriter, err error) { - if rl, ok := rw.(logging.ResponseLogger); ok { + rl, ok := rw.(logging.ResponseLogger) + if !ok { + log.Println(err) + + return + } + + rl.WithFields(map[string]interface{}{ + "error": err, + }) + + if os.Getenv("STEPDEBUG") != "1" { + return + } + + e, ok := err.(StackTracedError) + if !ok { + e, ok = errors.Cause(err).(StackTracedError) + } + + if ok { rl.WithFields(map[string]interface{}{ - "error": err, + "stack-trace": fmt.Sprintf("%+v", e.StackTrace()), }) - } else { - log.Println(err) } } diff --git a/api/rekey.go b/api/rekey.go index 269086bb..3116cf74 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -28,24 +29,24 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + render.Error(w, errs.BadRequest("missing client certificate")) return } var body RekeyRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return } certChainPEM := certChainToPEM(certChain) @@ -55,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/render/render.go b/api/render/render.go new file mode 100644 index 00000000..9df4c791 --- /dev/null +++ b/api/render/render.go @@ -0,0 +1,122 @@ +// Package render implements functionality related to response rendering. +package render + +import ( + "bytes" + "encoding/json" + "net/http" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/api/log" +) + +// JSON is shorthand for JSONStatus(w, v, http.StatusOK). +func JSON(w http.ResponseWriter, v interface{}) { + JSONStatus(w, v, http.StatusOK) +} + +// JSONStatus marshals v into w. It additionally sets the status code of +// w to the given one. +// +// JSONStatus sets the Content-Type of w to application/json unless one is +// specified. +func JSONStatus(w http.ResponseWriter, v interface{}, status int) { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(v); err != nil { + panic(err) + } + + setContentTypeUnlessPresent(w, "application/json") + w.WriteHeader(status) + _, _ = b.WriteTo(w) + + log.EnabledResponse(w, v) +} + +// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK). +func ProtoJSON(w http.ResponseWriter, m proto.Message) { + ProtoJSONStatus(w, m, http.StatusOK) +} + +// ProtoJSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { + b, err := protojson.Marshal(m) + if err != nil { + panic(err) + } + + setContentTypeUnlessPresent(w, "application/json") + w.WriteHeader(status) + _, _ = w.Write(b) +} + +func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) { + const header = "Content-Type" + + h := w.Header() + if _, ok := h[header]; !ok { + h.Set(header, contentType) + } +} + +// RenderableError is the set of errors that implement the basic Render method. +// +// Errors that implement this interface will use their own Render method when +// being rendered into responses. +type RenderableError interface { + error + + Render(http.ResponseWriter) +} + +// Error marshals the JSON representation of err to w. In case err implements +// RenderableError its own Render method will be called instead. +func Error(w http.ResponseWriter, err error) { + log.Error(w, err) + + if e, ok := err.(RenderableError); ok { + e.Render(w) + + return + } + + JSONStatus(w, err, statusCodeFromError(err)) +} + +// StatusCodedError is the set of errors that implement the basic StatusCode +// function. +// +// Errors that implement this interface will use the code reported by StatusCode +// as the HTTP response code when being rendered by this package. +type StatusCodedError interface { + error + + StatusCode() int +} + +func statusCodeFromError(err error) (code int) { + code = http.StatusInternalServerError + + type causer interface { + Cause() error + } + + for err != nil { + if sc, ok := err.(StatusCodedError); ok { + code = sc.StatusCode() + + break + } + + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + + return +} diff --git a/api/render/render_test.go b/api/render/render_test.go new file mode 100644 index 00000000..06d092d3 --- /dev/null +++ b/api/render/render_test.go @@ -0,0 +1,115 @@ +package render + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smallstep/certificates/logging" +) + +func TestJSON(t *testing.T) { + rec := httptest.NewRecorder() + rw := logging.NewResponseLogger(rec) + + JSON(rw, map[string]interface{}{"foo": "bar"}) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String()) + + assert.Empty(t, rw.Fields()) +} + +func TestJSONPanics(t *testing.T) { + assert.Panics(t, func() { + JSON(httptest.NewRecorder(), make(chan struct{})) + }) +} + +type renderableError struct { + Code int `json:"-"` + Message string `json:"message"` +} + +func (err renderableError) Error() string { + return err.Message +} + +func (err renderableError) Render(w http.ResponseWriter) { + w.Header().Set("Content-Type", "something/custom") + + JSONStatus(w, err, err.Code) +} + +type statusedError struct { + Contents string +} + +func (err statusedError) Error() string { return err.Contents } + +func (statusedError) StatusCode() int { return 432 } + +func TestError(t *testing.T) { + cases := []struct { + err error + code int + body string + header string + }{ + 0: { + err: renderableError{532, "some string"}, + code: 532, + body: "{\"message\":\"some string\"}\n", + header: "something/custom", + }, + 1: { + err: statusedError{"123"}, + code: 432, + body: "{\"Contents\":\"123\"}\n", + header: "application/json", + }, + } + + for caseIndex := range cases { + kase := cases[caseIndex] + + t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { + rec := httptest.NewRecorder() + + Error(rec, kase.err) + + assert.Equal(t, kase.code, rec.Result().StatusCode) + assert.Equal(t, kase.body, rec.Body.String()) + assert.Equal(t, kase.header, rec.Header().Get("Content-Type")) + }) + } +} + +type causedError struct { + cause error +} + +func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) } +func (err causedError) Cause() error { return err.cause } + +func TestStatusCodeFromError(t *testing.T) { + cases := []struct { + err error + exp int + }{ + 0: {nil, http.StatusInternalServerError}, + 1: {io.EOF, http.StatusInternalServerError}, + 2: {statusedError{"123"}, 432}, + 3: {causedError{statusedError{"432"}}, 432}, + } + + for caseIndex, kase := range cases { + assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex) + } +} diff --git a/api/renew.go b/api/renew.go index 408d91a3..9c4bff32 100644 --- a/api/renew.go +++ b/api/renew.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -18,13 +19,13 @@ const ( func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { cert, err := h.getPeerCertificate(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Renew(cert) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return } certChainPEM := certChainToPEM(certChain) @@ -34,7 +35,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/revoke.go b/api/revoke.go index 49822e6d..c9da2c18 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -51,12 +52,12 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -73,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { if len(body.OTT) > 0 { logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT @@ -82,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or client certificate")) + render.Error(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("serial number in client certificate different than body")) + render.Error(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. @@ -98,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } logRevoke(w, opts) - JSON(w, &RevokeResponse{Status: "ok"}) + render.JSON(w, &RevokeResponse{Status: "ok"}) } func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/sign.go b/api/sign.go index b2eef45d..b6bfcc8b 100644 --- a/api/sign.go +++ b/api/sign.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -51,13 +52,13 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -69,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { signOpts, err := h.Authority.AuthorizeSign(body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) @@ -84,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/ssh.go b/api/ssh.go index fc185d07..3b0de7c1 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -12,6 +12,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -252,19 +253,19 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -272,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -289,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } @@ -303,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -316,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } @@ -328,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return } identityCertificate = certChainToPEM(certChain) } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{cert}, AddUserCertificate: addUserCertificate, IdentityCertificate: identityCertificate, @@ -346,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHRoots(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -363,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHFederation is an HTTP handler that returns the federated SSH public keys @@ -371,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -388,7 +389,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHConfig is an HTTP handler that returns rendered templates for ssh clients @@ -396,17 +397,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } @@ -417,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: cfg.HostTemplates = ts default: - WriteError(w, errs.InternalServer("it should hot get here")) + render.Error(w, errs.InternalServer("it should hot get here")) return } - JSON(w, cfg) + render.JSON(w, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHCheckPrincipalResponse{ + render.JSON(w, &SSHCheckPrincipalResponse{ Exists: exists, }) } @@ -455,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHGetHostsResponse{ + render.JSON(w, &SSHGetHostsResponse{ Hosts: hosts, }) } @@ -467,21 +468,21 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHBastionResponse{ + render.JSON(w, &SSHBastionResponse{ Hostname: body.Hostname, Bastion: bastion, }) diff --git a/api/sshRekey.go b/api/sshRekey.go index b7581749..92278950 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -41,36 +42,37 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } @@ -80,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } - JSONStatus(w, &SSHRekeyResponse{ + render.JSONStatus(w, &SSHRekeyResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRenew.go b/api/sshRenew.go index b98466bf..78d16fa6 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -39,30 +40,31 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RenewSSH(ctx, oldCert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } @@ -72,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 2d2da1f7..a33082cd 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -6,6 +6,7 @@ import ( "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -50,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -71,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } logSSHRevoke(w, opts) - JSON(w, &SSHRevokeResponse{Status: "ok"}) + render.JSON(w, &SSHRevokeResponse{Status: "ok"}) } func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/utils.go b/api/utils.go deleted file mode 100644 index e3fcc9c4..00000000 --- a/api/utils.go +++ /dev/null @@ -1,57 +0,0 @@ -package api - -import ( - "encoding/json" - "net/http" - - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - "github.com/smallstep/certificates/api/log" -) - -// JSON writes the passed value into the http.ResponseWriter. -func JSON(w http.ResponseWriter, v interface{}) { - JSONStatus(w, v, http.StatusOK) -} - -// JSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func JSONStatus(w http.ResponseWriter, v interface{}, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(v); err != nil { - log.Error(w, err) - - return - } - - log.EnabledResponse(w, v) -} - -// ProtoJSON writes the passed value into the http.ResponseWriter. -func ProtoJSON(w http.ResponseWriter, m proto.Message) { - ProtoJSONStatus(w, m, http.StatusOK) -} - -// ProtoJSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - - b, err := protojson.Marshal(m) - if err != nil { - log.Error(w, err) - - return - } - - if _, err := w.Write(b); err != nil { - log.Error(w, err) - - return - } - - // log.EnabledResponse(w, v) -} diff --git a/api/utils_test.go b/api/utils_test.go deleted file mode 100644 index f5e1e1cb..00000000 --- a/api/utils_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package api - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/smallstep/certificates/logging" -) - -func TestJSON(t *testing.T) { - type args struct { - rw http.ResponseWriter - v interface{} - } - tests := []struct { - name string - args args - ok bool - }{ - {"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true}, - {"fail", args{httptest.NewRecorder(), make(chan int)}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rw := logging.NewResponseLogger(tt.args.rw) - JSON(rw, tt.args.v) - - rr, ok := tt.args.rw.(*httptest.ResponseRecorder) - if !ok { - t.Error("ResponseWriter does not implement *httptest.ResponseRecorder") - return - } - - fields := rw.Fields() - if tt.ok { - if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" { - t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body) - } - if len(fields) != 0 { - t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields) - } - } else { - if body := rr.Body.String(); body != "" { - t.Errorf("Unexpected body = %s, want empty string", body) - } - if len(fields) != 1 { - t.Errorf("ResponseLogger fields = %v, wants 1 element", fields) - } - } - }) - } -} diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 27c3ba6f..21a7229d 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -6,10 +6,12 @@ import ( "net/http" "github.com/go-chi/chi" - "github.com/smallstep/certificates/api" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) const ( @@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { provName := chi.URLParam(r, "provisionerName") eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !eabEnabled { - api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) return } ctx = context.WithValue(ctx, provisionerContextKey, prov) @@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder { // GetExternalAccountKeys writes the response for the EAB keys GET endpoint func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 43607c52..5e4b9c30 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -10,6 +10,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) @@ -85,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { adm, ok := h.auth.LoadAdminByID(id) if !ok { - api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } // GetAdmins returns a segment of admins associated with the authority. func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) + render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } - api.JSON(w, &GetAdminsResponse{ + render.JSON(w, &GetAdminsResponse{ Admins: admins, NextCursor: nextCursor, }) @@ -116,18 +117,18 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } p, err := h.auth.LoadProvisionerByName(body.Provisioner) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return } adm := &linkedca.Admin{ @@ -137,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // Store to authority collection. if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) + render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } - api.ProtoJSONStatus(w, adm, http.StatusCreated) + render.ProtoJSONStatus(w, adm, http.StatusCreated) } // DeleteAdmin deletes admin. @@ -149,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateAdmin updates an existing admin. func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -173,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 19025a9d..b57dd6eb 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" ) @@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request) func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { if !h.auth.IsAdminAPIEnabled() { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } @@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { - api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, + render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, "missing authorization header token")) return } adm, err := h.auth.AuthorizeAdminToken(r, tok) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 2106733d..1cad62dd 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -4,10 +4,12 @@ import ( "net/http" "github.com/go-chi/chi" + "go.step.sm/linkedca" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" @@ -33,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.ProtoJSON(w, prov) + render.ProtoJSON(w, prov) } // GetProvisioners returns the given segment of provisioners associated with the authority. func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } p, next, err := h.auth.GetProvisioners(cursor, limit) if err != nil { - api.WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - api.JSON(w, &GetProvisionersResponse{ + render.JSON(w, &GetProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -75,21 +77,21 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // TODO: Validate inputs if err := authority.ValidateClaims(prov.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) + render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } - api.ProtoJSONStatus(w, prov, http.StatusCreated) + render.ProtoJSONStatus(w, prov, http.StatusCreated) } // DeleteProvisioner deletes a provisioner. @@ -103,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) + render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateProvisioner updates an existing prov. func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } name := chi.URLParam(r, "name") _old, err := h.auth.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) return } if nu.Id != old.Id { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner ID")) return } if nu.Type != old.Type { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) + render.Error(w, admin.NewErrorISE("cannot change provisioner type")) return } if nu.AuthorityId != old.AuthorityId { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID")) return } if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt")) return } if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt")) return } // TODO: Validate inputs if err := authority.ValidateClaims(nu.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.ProtoJSON(w, nu) + render.ProtoJSON(w, nu) } diff --git a/authority/admin/errors.go b/authority/admin/errors.go index 217227ca..baa32dd9 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -3,13 +3,10 @@ package admin import ( "encoding/json" "fmt" - "log" "net/http" - "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the Admin problem. @@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) { return string(b), nil } -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err *Error) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(err.StatusCode()) - - err.Message = err.Err.Error() - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err.Err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.Err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } +// Render implements render.RenderableError for Error. +func (e *Error) Render(w http.ResponseWriter) { + e.Message = e.Err.Error() - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Println(err) - } + render.JSONStatus(w, e, e.StatusCode()) } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index b631741a..81e542c5 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/base64" + "errors" "fmt" "net/http" "reflect" @@ -16,16 +17,18 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) var testAudiences = provisioner.Audiences{ @@ -310,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) { p, err := tc.auth.authorizeToken(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -396,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) { if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -481,8 +484,8 @@ func TestAuthority_authorizeSign(t *testing.T) { got, err := tc.auth.authorizeSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -740,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, got) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -853,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { err := tc.auth.authorizeRenew(tc.cert) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1001,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1118,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1218,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) { if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1311,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index bc4e97e0..49ac9468 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -3,13 +3,14 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/api/render" ) func TestACME_Getters(t *testing.T) { @@ -114,7 +115,7 @@ func TestACME_AuthorizeRenew(t *testing.T) { NotAfter: now.Add(time.Hour), }, code: http.StatusUnauthorized, - err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -133,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -168,8 +169,8 @@ func TestACME_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -192,7 +193,7 @@ func TestACME_AuthorizeSign(t *testing.T) { assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 559a48f1..1b7efa7c 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "encoding/hex" "encoding/pem" + "errors" "fmt" "net" "net/http" @@ -17,10 +18,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestAWS_Getters(t *testing.T) { @@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -668,8 +669,8 @@ func TestAWS_AuthorizeSign(t *testing.T) { t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) @@ -698,7 +699,7 @@ func TestAWS_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -802,8 +803,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -860,8 +861,8 @@ func TestAWS_AuthorizeRenew(t *testing.T) { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 40bb4698..8002563c 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "net/http" "net/http/httptest" @@ -15,10 +16,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestAzure_Getters(t *testing.T) { @@ -335,8 +336,8 @@ func TestAzure_authorizeToken(t *testing.T) { tc := tt(t) if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -497,8 +498,8 @@ func TestAzure_AuthorizeSign(t *testing.T) { t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) @@ -527,7 +528,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -572,8 +573,8 @@ func TestAzure_AuthorizeRenew(t *testing.T) { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -668,8 +669,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index b8c437c3..4ac42bff 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "net/http" "net/http/httptest" @@ -16,10 +17,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestGCP_Getters(t *testing.T) { @@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -540,8 +541,8 @@ func TestGCP_AuthorizeSign(t *testing.T) { t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) @@ -570,7 +571,7 @@ func TestGCP_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -677,8 +678,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -734,7 +735,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index dde2f836..215d9c84 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -6,15 +6,17 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" + "fmt" "net/http" "strings" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestJWK_Getters(t *testing.T) { @@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -288,8 +290,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -315,7 +317,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -361,8 +363,8 @@ func TestJWK_AuthorizeRenew(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -455,8 +457,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -620,8 +622,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 378d4471..b1aa3b55 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -3,14 +3,16 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestK8sSA_Getters(t *testing.T) { @@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -202,7 +204,7 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { NotAfter: now.Add(time.Hour), }, code: http.StatusUnauthorized, - err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -221,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -270,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -294,7 +296,7 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } @@ -326,7 +328,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), + err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { @@ -357,8 +359,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -378,7 +380,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshDefaultDuration: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index c1a94b1d..18b568a7 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -6,16 +6,17 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" "fmt" "net/http" "strings" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func Test_openIDConfiguration_Validate(t *testing.T) { @@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else { @@ -317,8 +318,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -340,7 +341,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) { case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -402,8 +403,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -448,8 +449,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -604,8 +605,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -672,8 +673,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 330d1b57..9678a20b 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -2,13 +2,14 @@ package provisioner import ( "context" + "errors" "net/http" "testing" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestType_String(t *testing.T) { @@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) { default: t.Errorf("unexpected method %s", tt.method) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized) assert.Equals(t, err.Error(), msg) }) diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index b548fe71..13294866 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -5,16 +5,19 @@ import ( "crypto" "crypto/rand" "encoding/base64" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" - "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestSSHPOP_Getters(t *testing.T) { @@ -215,8 +218,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -286,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -367,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { tc := tt(t) if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -449,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { tc := tt(t) if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -464,7 +467,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } assert.Equals(t, tc.cert.Nonce, cert.Nonce) diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 7932d045..22dd8541 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -3,16 +3,18 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestX5C_Getters(t *testing.T) { @@ -71,7 +73,7 @@ func TestX5C_Init(t *testing.T) { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, - err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), + err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { @@ -122,7 +124,7 @@ M46l92gdOozT // check the number of certificates in the pool. numCerts := len(p.rootPool.Subjects()) if numCerts != 2 { - return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) + return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts) } return nil }, @@ -387,8 +389,8 @@ lgsqsR63is+0YQ== tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -458,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -490,7 +492,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -541,8 +543,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -572,7 +574,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) { return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -591,8 +593,8 @@ func TestX5C_AuthorizeRenew(t *testing.T) { NotAfter: now.Add(time.Hour), }); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -631,7 +633,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), + err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { @@ -752,7 +754,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -787,7 +789,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 3975031b..81dc38bf 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,13 +1,13 @@ package authority import ( + "errors" "net/http" "testing" - "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/errs" ) func TestGetEncryptedKey(t *testing.T) { @@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) { ek, err := tc.a.GetEncryptedKey(tc.kid) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) { ps, next, err := tc.a.GetProvisioners("", 0) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/root_test.go b/authority/root_test.go index 6e5f1932..a1b08fac 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -2,14 +2,15 @@ package authority import ( "crypto/x509" + "errors" "net/http" "reflect" "testing" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/pemutil" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestRoot(t *testing.T) { @@ -31,7 +32,7 @@ func TestRoot(t *testing.T) { crt, err := a.Root(tc.sum) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index c299b347..ce840fe1 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -7,21 +7,22 @@ import ( "crypto/rand" "crypto/x509" "encoding/base64" + "errors" "fmt" "net/http" "reflect" "testing" "time" - "github.com/pkg/errors" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/sshutil" + "golang.org/x/crypto/ssh" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/templates" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/sshutil" - "golang.org/x/crypto/ssh" ) type sshTestModifier ssh.Certificate @@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { - _, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + _, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) @@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) { hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) { cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/tls_test.go b/authority/tls_test.go index 6ccf02ca..e199e0c5 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -11,24 +11,26 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/pem" + "errors" "fmt" "net/http" "reflect" "testing" "time" - "github.com/smallstep/certificates/cas/softcas" + "gopkg.in/square/go-jose.v2/jwt" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" - "gopkg.in/square/go-jose.v2/jwt" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/cas/softcas" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) var ( @@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) { func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { b, err := x509.MarshalPKIXPublicKey(pub) if err != nil { - return nil, errors.Wrap(err, "error marshaling public key") + return nil, fmt.Errorf("error marshaling public key: %w", err) } info := struct { Algorithm pkix.AlgorithmIdentifier SubjectPublicKey asn1.BitString }{} if _, err = asn1.Unmarshal(b, &info); err != nil { - return nil, errors.Wrap(err, "error unmarshaling public key") + return nil, fmt.Errorf("error unmarshaling public key: %w", err) } hash := sha1.Sum(info.SubjectPublicKey.Bytes) return hash[:], nil @@ -661,8 +663,8 @@ ZYtQ9Ot36qc= if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -860,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1067,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1456,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) if err := tc.auth.Revoke(ctx, tc.opts); err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index ad5f2116..f17a2f7a 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -12,12 +12,13 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/pemutil" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" - "github.com/smallstep/certificates/api" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/pemutil" + "github.com/smallstep/certificates/api/render" ) func TestNewACMEClient(t *testing.T) { @@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header switch { case i == 0: - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ case i == 1: w.Header().Set("Replay-Nonce", "abc123") - api.JSONStatus(w, []byte{}, 200) + render.JSONStatus(w, []byte{}, 200) i++ default: w.Header().Set("Location", accLocation) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) @@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) }) if nonce, err := ac.GetNonce(); err != nil { @@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) { assert.Equals(t, hdr.KeyID, ac.kid) } - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { @@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, norb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.NewOrder(norb); err != nil { @@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetOrder(url); err != nil { @@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetAuthz(url); err != nil { @@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetChallenge(url); err != nil { @@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { assert.Equals(t, payload, []byte("{}")) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.ValidateChallenge(url); err != nil { @@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, frb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.FinalizeOrder(url, csr); err != nil { @@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := tc.client.GetAccountOrders(); err != nil { @@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { if tc.certBytes != nil { w.Write(tc.certBytes) } else { - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 0e16bd7d..2332b4d4 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -14,11 +14,14 @@ import ( "time" "github.com/pkg/errors" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/randutil" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/randutil" ) func newLocalListener() net.Listener { @@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/version" { - api.JSON(w, api.VersionResponse{ + render.JSON(w, api.VersionResponse{ Version: "test", RequireClientAuthentication: true, }) @@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 if !isMTLS { - api.WriteError(w, errs.Unauthorized("missing peer certificate")) + render.Error(w, errs.Unauthorized("missing peer certificate")) } else { next.ServeHTTP(w, r) } diff --git a/ca/client_test.go b/ca/client_test.go index 4628d19b..48aa1488 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "net/http" "net/http/httptest" @@ -16,14 +17,13 @@ import ( "testing" "time" - "github.com/pkg/errors" - "golang.org/x/crypto/ssh" - "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -182,7 +182,7 @@ func TestClient_Version(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() @@ -232,7 +232,7 @@ func TestClient_Health(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() @@ -290,7 +290,7 @@ func TestClient_Root(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) @@ -360,7 +360,7 @@ func TestClient_Sign(t *testing.T) { if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -371,7 +371,7 @@ func TestClient_Sign(t *testing.T) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) @@ -432,7 +432,7 @@ func TestClient_Revoke(t *testing.T) { if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -443,7 +443,7 @@ func TestClient_Revoke(t *testing.T) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) @@ -503,7 +503,7 @@ func TestClient_Renew(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) @@ -519,8 +519,8 @@ func TestClient_Renew(t *testing.T) { t.Errorf("Client.Renew() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -568,9 +568,9 @@ func TestClient_RenewWithToken(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Header.Get("Authorization") != "Bearer token" { - api.JSONStatus(w, errs.InternalServer("force"), 500) + render.JSONStatus(w, errs.InternalServer("force"), 500) } else { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) } }) @@ -587,8 +587,8 @@ func TestClient_RenewWithToken(t *testing.T) { t.Errorf("Client.RenewWithToken() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -640,7 +640,7 @@ func TestClient_Rekey(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) @@ -656,8 +656,8 @@ func TestClient_Rekey(t *testing.T) { t.Errorf("Client.Renew() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -705,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) { if req.RequestURI != tt.expectedURI { t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) @@ -762,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) @@ -777,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) { t.Errorf("Client.ProvisionerKey() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -821,7 +821,7 @@ func TestClient_Roots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() @@ -836,8 +836,8 @@ func TestClient_Roots(t *testing.T) { if got != nil { t.Errorf("Client.Roots() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -879,7 +879,7 @@ func TestClient_Federation(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() @@ -894,8 +894,8 @@ func TestClient_Federation(t *testing.T) { if got != nil { t.Errorf("Client.Federation() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -941,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() @@ -956,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) { if got != nil { t.Errorf("Client.SSHKeys() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -1041,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) { } tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() @@ -1102,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) @@ -1118,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) { t.Errorf("Client.SSHBastion() = %v, want nil", got) } if tt.responseCode != 200 { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) } diff --git a/errs/error.go b/errs/error.go index 60da9e1f..c42e342d 100644 --- a/errs/error.go +++ b/errs/error.go @@ -6,17 +6,10 @@ import ( "net/http" "github.com/pkg/errors" -) - -// StatusCoder interface is used by errors that returns the HTTP response code. -type StatusCoder interface { - StatusCode() int -} -// StackTracer must be by those errors that return an stack trace. -type StackTracer interface { - StackTrace() errors.StackTrace -} + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" +) // Option modifies the Error type. type Option func(e *Error) error @@ -257,7 +250,7 @@ func NewError(status int, err error, format string, args ...interface{}) error { return err } msg := fmt.Sprintf(format, args...) - if _, ok := err.(StackTracer); !ok { + if _, ok := err.(log.StackTracedError); !ok { err = errors.Wrap(err, msg) } return &Error{ @@ -275,11 +268,11 @@ func NewErr(status int, err error, opts ...Option) error { ok bool ) if e, ok = err.(*Error); !ok { - if sc, ok := err.(StatusCoder); ok { + if sc, ok := err.(render.StatusCodedError); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { cause := errors.Cause(err) - if sc, ok := cause.(StatusCoder); ok { + if sc, ok := cause.(render.StatusCodedError); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { e = &Error{Status: status, Err: err} From 055e75f3941f423acf87b95d60cb2a8252fade35 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Wed, 30 Mar 2022 15:48:42 -0700 Subject: [PATCH 10/20] Progress? --- authority/authority.go | 26 ++++++++++++++++++-------- ca/ca.go | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index 50025cce..b6829861 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -80,6 +80,14 @@ type Authority struct { adminMutex sync.RWMutex } +type AuthorityInfo struct { + StartTime time.Time + RootX509Certs []*x509.Certificate + SSHCAUserCerts []ssh.PublicKey + SSHCAHostCerts []ssh.PublicKey +} + + // New creates and initiates a new Authority type. func New(cfg *config.Config, opts ...Option) (*Authority, error) { err := cfg.Validate() @@ -311,7 +319,6 @@ func (a *Authority) init() error { for _, crt := range a.rootX509Certs { sum := sha256.Sum256(crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), crt) - log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) } a.rootX509CertPool = x509.NewCertPool() @@ -540,13 +547,6 @@ func (a *Authority) init() error { a.templates.Data["Step"] = tmplVars } - if tmplVars.SSH.HostKey != nil { - log.Printf("SSH Host CA Key: %s\n", ssh.MarshalAuthorizedKey(tmplVars.SSH.HostKey)) - } - if tmplVars.SSH.UserKey != nil { - log.Printf("SSH User CA Key: %s\n", ssh.MarshalAuthorizedKey(tmplVars.SSH.UserKey)) - } - // JWT numeric dates are seconds. a.startTime = time.Now().Truncate(time.Second) // Set flag indicating that initialization has been completed, and should @@ -567,6 +567,16 @@ func (a *Authority) GetAdminDatabase() admin.DB { return a.adminDB } +func (a *Authority) GetAuthorityInfo() *AuthorityInfo { + return &AuthorityInfo{ + StartTime: a.startTime, + RootX509Certs: a.rootX509Certs, + SSHCAUserCerts: a.sshCAUserCerts, + SSHCAHostCerts: a.sshCAHostCerts, + } + +} + // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has // been enabled. func (a *Authority) IsAdminAPIEnabled() bool { diff --git a/ca/ca.go b/ca/ca.go index 41f48483..223d2470 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -3,6 +3,8 @@ package ca import ( "crypto/tls" "crypto/x509" + "crypto/sha256" + "encoding/hex" "fmt" "log" "net/http" @@ -297,6 +299,18 @@ func (ca *CA) Run() error { errs := make(chan error, 1) if !ca.opts.quiet { + authorityInfo := ca.auth.GetAuthorityInfo() + log.Printf("Address: %s", ca.config.Address) + for _, crt := range authorityInfo.RootX509Certs { + sum := sha256.Sum256(crt.Raw) + log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) + } + if ca.config.SSH != nil { + log.Printf("SSH Host CA Key: %s\n", ca.config.SSH.HostKey) + } + if ca.config.SSH != nil { + log.Printf("SSH User CA Key: %s\n", ca.config.SSH.UserKey) + } log.Printf("Documentation: https://u.step.sm/docs/ca") log.Printf("Community Discord: https://u.step.sm/discord") log.Printf("Config File: %s", ca.opts.configFile) From 90cb6315b187a751cc32dcd6c36f1476302e8ddf Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Wed, 30 Mar 2022 16:05:26 -0700 Subject: [PATCH 11/20] Progress. --- authority/authority.go | 16 ++++++++++------ ca/ca.go | 8 ++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index b6829861..f2b8b983 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -83,8 +83,8 @@ type Authority struct { type AuthorityInfo struct { StartTime time.Time RootX509Certs []*x509.Certificate - SSHCAUserCerts []ssh.PublicKey - SSHCAHostCerts []ssh.PublicKey + SSHCAUserPublicKey []byte + SSHCAHostPublicKey []byte } @@ -568,13 +568,17 @@ func (a *Authority) GetAdminDatabase() admin.DB { } func (a *Authority) GetAuthorityInfo() *AuthorityInfo { - return &AuthorityInfo{ + ai := &AuthorityInfo{ StartTime: a.startTime, RootX509Certs: a.rootX509Certs, - SSHCAUserCerts: a.sshCAUserCerts, - SSHCAHostCerts: a.sshCAHostCerts, } - + if a.sshCAUserCertSignKey != nil { + ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey()) + } + if a.sshCAHostCertSignKey != nil { + ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey()) + } + return ai } // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has diff --git a/ca/ca.go b/ca/ca.go index 223d2470..0e7f3dbb 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -305,11 +305,11 @@ func (ca *CA) Run() error { sum := sha256.Sum256(crt.Raw) log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) } - if ca.config.SSH != nil { - log.Printf("SSH Host CA Key: %s\n", ca.config.SSH.HostKey) + if authorityInfo.SSHCAHostPublicKey != nil { + log.Printf("SSH Host CA Key: %s\n", authorityInfo.SSHCAHostPublicKey) } - if ca.config.SSH != nil { - log.Printf("SSH User CA Key: %s\n", ca.config.SSH.UserKey) + if authorityInfo.SSHCAUserPublicKey != nil { + log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey) } log.Printf("Documentation: https://u.step.sm/docs/ca") log.Printf("Community Discord: https://u.step.sm/discord") From a13e58e3407dc169739e42e0b7aa60d01f071175 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Wed, 30 Mar 2022 16:07:16 -0700 Subject: [PATCH 12/20] Update GetAuthorityInfo -> GetInfo --- authority/authority.go | 4 ++-- ca/ca.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index f2b8b983..3b26dbc5 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -567,8 +567,8 @@ func (a *Authority) GetAdminDatabase() admin.DB { return a.adminDB } -func (a *Authority) GetAuthorityInfo() *AuthorityInfo { - ai := &AuthorityInfo{ +func (a *Authority) GetInfo() AuthorityInfo { + ai := AuthorityInfo{ StartTime: a.startTime, RootX509Certs: a.rootX509Certs, } diff --git a/ca/ca.go b/ca/ca.go index 0e7f3dbb..bf967aed 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -299,7 +299,7 @@ func (ca *CA) Run() error { errs := make(chan error, 1) if !ca.opts.quiet { - authorityInfo := ca.auth.GetAuthorityInfo() + authorityInfo := ca.auth.GetInfo() log.Printf("Address: %s", ca.config.Address) for _, crt := range authorityInfo.RootX509Certs { sum := sha256.Sum256(crt.Raw) From 1ba1584c7a2abb32b073d1e13a978868a0e91b2a Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Wed, 30 Mar 2022 16:08:10 -0700 Subject: [PATCH 13/20] Formatted. --- authority/authority.go | 11 +++++------ ca/ca.go | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index 3b26dbc5..c5f8c3a6 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -81,13 +81,12 @@ type Authority struct { } type AuthorityInfo struct { - StartTime time.Time - RootX509Certs []*x509.Certificate - SSHCAUserPublicKey []byte - SSHCAHostPublicKey []byte + StartTime time.Time + RootX509Certs []*x509.Certificate + SSHCAUserPublicKey []byte + SSHCAHostPublicKey []byte } - // New creates and initiates a new Authority type. func New(cfg *config.Config, opts ...Option) (*Authority, error) { err := cfg.Validate() @@ -569,7 +568,7 @@ func (a *Authority) GetAdminDatabase() admin.DB { func (a *Authority) GetInfo() AuthorityInfo { ai := AuthorityInfo{ - StartTime: a.startTime, + StartTime: a.startTime, RootX509Certs: a.rootX509Certs, } if a.sshCAUserCertSignKey != nil { diff --git a/ca/ca.go b/ca/ca.go index bf967aed..21b64ee7 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -1,9 +1,9 @@ package ca import ( + "crypto/sha256" "crypto/tls" "crypto/x509" - "crypto/sha256" "encoding/hex" "fmt" "log" @@ -302,8 +302,8 @@ func (ca *CA) Run() error { authorityInfo := ca.auth.GetInfo() log.Printf("Address: %s", ca.config.Address) for _, crt := range authorityInfo.RootX509Certs { - sum := sha256.Sum256(crt.Raw) - log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) + sum := sha256.Sum256(crt.Raw) + log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) } if authorityInfo.SSHCAHostPublicKey != nil { log.Printf("SSH Host CA Key: %s\n", authorityInfo.SSHCAHostPublicKey) From f5bf46b950422344837d191639928b1a1584c325 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 30 Mar 2022 18:24:17 -0700 Subject: [PATCH 14/20] Upgrade go.step.sm/crypto --- CHANGELOG.md | 1 + go.mod | 3 ++- go.sum | 7 +++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73c338f8..49e4b15e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.3] - DATE ### Added - Added support for renew after expiry using the claim `allowRenewAfterExpiry`. +- Added support for `extraNames` in X.509 templates. ### Changed - Made SCEP CA URL paths dynamic - Support two latest versions of Go (1.17, 1.18) diff --git a/go.mod b/go.mod index d8b47e5f..17ea33fe 100644 --- a/go.mod +++ b/go.mod @@ -33,10 +33,11 @@ require ( github.com/slackhq/nebula v1.5.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/nosql v0.4.0 + github.com/stretchr/testify v1.7.0 github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 - go.step.sm/crypto v0.15.3 + go.step.sm/crypto v0.16.1 go.step.sm/linkedca v0.11.0 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd diff --git a/go.sum b/go.sum index b7081349..e7ddd660 100644 --- a/go.sum +++ b/go.sum @@ -468,9 +468,11 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -706,8 +708,8 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= -go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10= -go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= +go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk= +go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= go.step.sm/linkedca v0.11.0 h1:jkG5XDQz9VSz2PH+cGjDvJTwiIziN0SWExTnicWpb8o= go.step.sm/linkedca v0.11.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -1195,6 +1197,7 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= From 7ebb2e4c74f4f0510198baf1a0a1301b7e105e84 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Mon, 4 Apr 2022 11:14:04 -0700 Subject: [PATCH 15/20] Update ca/ca.go Co-authored-by: Herman Slatman --- ca/ca.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 21b64ee7..185fb72e 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -302,8 +302,7 @@ func (ca *CA) Run() error { authorityInfo := ca.auth.GetInfo() log.Printf("Address: %s", ca.config.Address) for _, crt := range authorityInfo.RootX509Certs { - sum := sha256.Sum256(crt.Raw) - log.Printf("X.509 Root Fingerprint: %s", hex.EncodeToString(sum[:])) + log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt)) } if authorityInfo.SSHCAHostPublicKey != nil { log.Printf("SSH Host CA Key: %s\n", authorityInfo.SSHCAHostPublicKey) From 43f2c655b909b02e30b5c4aaff898cfe45525a60 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Mon, 4 Apr 2022 12:16:37 -0700 Subject: [PATCH 16/20] More info on startup --- authority/authority.go | 2 ++ ca/ca.go | 15 +++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index c5f8c3a6..8c5eb9c4 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -85,6 +85,7 @@ type AuthorityInfo struct { RootX509Certs []*x509.Certificate SSHCAUserPublicKey []byte SSHCAHostPublicKey []byte + DNSNames []string } // New creates and initiates a new Authority type. @@ -570,6 +571,7 @@ func (a *Authority) GetInfo() AuthorityInfo { ai := AuthorityInfo{ StartTime: a.startTime, RootX509Certs: a.rootX509Certs, + DNSNames: a.config.DNSNames, } if a.sshCAUserCertSignKey != nil { ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey()) diff --git a/ca/ca.go b/ca/ca.go index 185fb72e..89813d64 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -1,15 +1,14 @@ package ca import ( - "crypto/sha256" "crypto/tls" "crypto/x509" - "encoding/hex" "fmt" "log" "net/http" "net/url" "reflect" + "strings" "sync" "github.com/go-chi/chi" @@ -28,6 +27,7 @@ import ( scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" "github.com/smallstep/nosql" + "go.step.sm/crypto/x509util" ) type options struct { @@ -300,12 +300,19 @@ func (ca *CA) Run() error { if !ca.opts.quiet { authorityInfo := ca.auth.GetInfo() - log.Printf("Address: %s", ca.config.Address) + log.Printf("Welcome to step-ca.") + log.Printf("The primary server URL is https://%s%s", + authorityInfo.DNSNames[0], + ca.config.Address[strings.LastIndex(ca.config.Address, ":"):]) + if len(authorityInfo.DNSNames) > 1 { + log.Printf("Additional configured hostnames: %s", + strings.Join(authorityInfo.DNSNames[1:], ", ")) + } for _, crt := range authorityInfo.RootX509Certs { log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt)) } if authorityInfo.SSHCAHostPublicKey != nil { - log.Printf("SSH Host CA Key: %s\n", authorityInfo.SSHCAHostPublicKey) + log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey) } if authorityInfo.SSHCAUserPublicKey != nil { log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey) From acc75bc679f3f13f6424343133559f88ba8fa868 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Mon, 4 Apr 2022 12:29:27 -0700 Subject: [PATCH 17/20] Add context name to startup info --- ca/ca.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index e509f74d..fce80e00 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -27,6 +27,7 @@ import ( scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" "github.com/smallstep/nosql" + "go.step.sm/cli-utils/step" "go.step.sm/crypto/x509util" ) @@ -301,6 +302,10 @@ func (ca *CA) Run() error { if !ca.opts.quiet { authorityInfo := ca.auth.GetInfo() log.Printf("Welcome to step-ca.") + log.Printf("Documentation: https://u.step.sm/docs/ca") + log.Printf("Community Discord: https://u.step.sm/discord") + log.Printf("Current context: %s", step.Contexts().GetCurrent().Name) + log.Printf("Config file: %s", ca.opts.configFile) log.Printf("The primary server URL is https://%s%s", authorityInfo.DNSNames[0], ca.config.Address[strings.LastIndex(ca.config.Address, ":"):]) @@ -317,9 +322,6 @@ func (ca *CA) Run() error { if authorityInfo.SSHCAUserPublicKey != nil { log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey) } - log.Printf("Documentation: https://u.step.sm/docs/ca") - log.Printf("Community Discord: https://u.step.sm/discord") - log.Printf("Config File: %s", ca.opts.configFile) } if ca.insecureSrv != nil { From 150eee70df9a6381bef1713fc0d38d42d46e074f Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Tue, 5 Apr 2022 10:59:25 -0700 Subject: [PATCH 18/20] Updates based on Herman's feedback --- authority/authority.go | 6 +++--- ca/ca.go | 10 +++++++--- commands/app.go | 1 + 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index 8c5eb9c4..b6071060 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -80,7 +80,7 @@ type Authority struct { adminMutex sync.RWMutex } -type AuthorityInfo struct { +type Info struct { StartTime time.Time RootX509Certs []*x509.Certificate SSHCAUserPublicKey []byte @@ -567,8 +567,8 @@ func (a *Authority) GetAdminDatabase() admin.DB { return a.adminDB } -func (a *Authority) GetInfo() AuthorityInfo { - ai := AuthorityInfo{ +func (a *Authority) GetInfo() Info { + ai := Info{ StartTime: a.startTime, RootX509Certs: a.rootX509Certs, DNSNames: a.config.DNSNames, diff --git a/ca/ca.go b/ca/ca.go index fce80e00..0d4f1578 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -301,14 +301,18 @@ func (ca *CA) Run() error { if !ca.opts.quiet { authorityInfo := ca.auth.GetInfo() - log.Printf("Welcome to step-ca.") + log.Printf("Starting %s", step.Version()) log.Printf("Documentation: https://u.step.sm/docs/ca") log.Printf("Community Discord: https://u.step.sm/discord") - log.Printf("Current context: %s", step.Contexts().GetCurrent().Name) + if step.Contexts().GetCurrent() != nil { + log.Printf("Current context: %s", step.Contexts().GetCurrent().Name) + } log.Printf("Config file: %s", ca.opts.configFile) - log.Printf("The primary server URL is https://%s%s", + baseURL := fmt.Sprintf("https://%s%s", authorityInfo.DNSNames[0], ca.config.Address[strings.LastIndex(ca.config.Address, ":"):]) + log.Printf("The primary server URL is %s", baseURL) + log.Printf("Root certificates are available at %s/roots.pem", baseURL) if len(authorityInfo.DNSNames) > 1 { log.Printf("Additional configured hostnames: %s", strings.Join(authorityInfo.DNSNames[1:], ", ")) diff --git a/commands/app.go b/commands/app.go index 6297581f..c3eacd02 100644 --- a/commands/app.go +++ b/commands/app.go @@ -61,6 +61,7 @@ certificate issuer private key used in the RA mode.`, cli.BoolFlag{ Name: "quiet", Usage: "disable startup information", + EnvVar: "STEP_CA_QUIET", }, cli.StringFlag{ Name: "context", From 2e61e01f41a646ae89a1f660ab6643f59f4ed3d9 Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Tue, 5 Apr 2022 10:59:35 -0700 Subject: [PATCH 19/20] Linted. --- commands/app.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/commands/app.go b/commands/app.go index c3eacd02..265610f2 100644 --- a/commands/app.go +++ b/commands/app.go @@ -59,8 +59,8 @@ certificate issuer private key used in the RA mode.`, EnvVar: "STEP_CA_TOKEN", }, cli.BoolFlag{ - Name: "quiet", - Usage: "disable startup information", + Name: "quiet", + Usage: "disable startup information", EnvVar: "STEP_CA_QUIET", }, cli.StringFlag{ From 479c6d2bf563fcc0073779e5f736823d135a3fe2 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 7 Apr 2022 12:37:34 +0200 Subject: [PATCH 20/20] Fix ACME IPv6 HTTP-01 challenges Fixes #890 --- acme/challenge.go | 13 ++++++++++++- acme/challenge_test.go | 36 ++++++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/acme/challenge.go b/acme/challenge.go index 0e1994e4..9f08bae5 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, } func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { - u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} + u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} resp, err := vo.HTTPGet(u.String()) if err != nil { @@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return nil } +// http01ChallengeHost checks if a Challenge value is an IPv6 address +// and adds square brackets if that's the case, so that it can be used +// as a hostname. Returns the original Challenge value as the host to +// use in other cases. +func http01ChallengeHost(value string) string { + if ip := net.ParseIP(value); ip != nil && ip.To4() == nil { + value = "[" + value + "]" + } + return value +} + func tlsAlert(err error) uint8 { var opErr *net.OpError if errors.As(err, &opErr) { diff --git a/acme/challenge_test.go b/acme/challenge_test.go index d8ce4d76..c05b25e7 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -13,6 +13,7 @@ import ( "encoding/asn1" "encoding/base64" "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -23,9 +24,9 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" ) func Test_storeError(t *testing.T) { @@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) { }) } } + +func Test_http01ChallengeHost(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + { + name: "dns", + value: "www.example.com", + want: "www.example.com", + }, + { + name: "ipv4", + value: "127.0.0.1", + want: "127.0.0.1", + }, + { + name: "ipv6", + value: "::1", + want: "[::1]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := http01ChallengeHost(tt.value); got != tt.want { + t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want) + } + }) + } +}