diff --git a/api/api.go b/api/api.go index fda27c42..9c2f1f31 100644 --- a/api/api.go +++ b/api/api.go @@ -40,6 +40,7 @@ type Authority interface { Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) + RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error) diff --git a/api/api_test.go b/api/api_test.go index abbbbd5b..e24751b3 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -192,6 +192,7 @@ type mockAuthority struct { sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) + renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByName func(name string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) @@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.renewContext != nil { + return m.renewContext(ctx, oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { if m.rekey != nil { return m.rekey(oldcert, pk) diff --git a/api/renew.go b/api/renew.go index 6e9f680f..1b9ed95f 100644 --- a/api/renew.go +++ b/api/renew.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/errs" ) @@ -17,14 +18,22 @@ const ( // Renew uses the information of certificate in the TLS connection to create a // new one. func Renew(w http.ResponseWriter, r *http.Request) { - cert, err := getPeerCertificate(r) + ctx := r.Context() + + // Get the leaf certificate from the peer or the token. + cert, token, err := getPeerCertificate(r) if err != nil { render.Error(w, err) return } - a := mustAuthority(r.Context()) - certChain, err := a.Renew(cert) + // The token can be used by RAs to renew a certificate. + if token != "" { + ctx = authority.NewTokenContext(ctx, token) + } + + a := mustAuthority(ctx) + certChain, err := a.RenewContext(ctx, cert, nil) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) { }, http.StatusCreated) } -func getPeerCertificate(r *http.Request) (*x509.Certificate, error) { +func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { - return r.TLS.PeerCertificates[0], nil + return r.TLS.PeerCertificates[0], "", nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { ctx := r.Context() - return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) + peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) + return peer, parts[1], err } } - return nil, errs.BadRequest("missing client certificate") + return nil, "", errs.BadRequest("missing client certificate") } diff --git a/authority/authorize.go b/authority/authorize.go index 44956cbd..1e50da89 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { // extra extension cannot be found, authorize the renewal by default. // // TODO(mariano): should we authorize by default? -func (a *Authority) authorizeRenew(cert *x509.Certificate) error { +func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} @@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) } } - if err := p.AuthorizeRenew(context.Background(), cert); err != nil { + if err := p.AuthorizeRenew(ctx, cert); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } return nil @@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509. } audiences := a.config.GetAudiences().Renew - if !matchesAudience(claims.Audience, audiences) { + if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) { return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 7dc22f3a..bec34fd6 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - err := tc.auth.authorizeRenew(tc.cert) + err := tc.auth.authorizeRenew(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError @@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { }) return nil })) + a4 := testAuthority(t) + a4.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"}, + RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, + }, nil + }, + } + t4, c4 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://ra.example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-ca-client/1.0", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", @@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { {"ok", a1, args{ctx, t1}, c1, false}, {"ok expired cert", a1, args{ctx, t2}, c2, false}, {"ok provisioner issuer", a1, args{ctx, t3}, c3, false}, + {"ok ra provisioner", a4, args{ctx, t4}, c4, false}, {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, {"fail token reuse", a1, args{ctx, t1}, nil, true}, {"fail token signature", a1, args{ctx, badSigner}, nil, true}, diff --git a/authority/provisioners.go b/authority/provisioners.go index bfa4eae5..d8a7b4d1 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa } } +// wrapRAProvisioner wraps the given provisioner with RA information. +func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner { + return &wrappedProvisioner{ + Interface: p, + raInfo: raInfo, + } +} + +// isRAProvisioner returns if the given provisioner is an RA provisioner. +func isRAProvisioner(p provisioner.Interface) bool { + if rap, ok := p.(raProvisioner); ok { + return rap.RAInfo() != nil + } + return false +} + // wrappedProvisioner implements raProvisioner and attProvisioner. type wrappedProvisioner struct { provisioner.Interface @@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr } if err == nil && data != nil && data.Provisioner != nil { if p, ok := a.provisioners.Load(data.Provisioner.ID); ok { + if data.RaInfo != nil { + return wrapRAProvisioner(p, data.RaInfo), nil + } return p, nil } } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 6ef62223..7901de6a 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) { }) } } + +func Test_wrapRAProvisioner(t *testing.T) { + type args struct { + p provisioner.Interface + raInfo *provisioner.RAInfo + } + tests := []struct { + name string + args args + want *wrappedProvisioner + }{ + {"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{ + Interface: &provisioner.JWK{Name: "jwt"}, + raInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_isRAProvisioner(t *testing.T) { + type args struct { + p provisioner.Interface + } + tests := []struct { + name string + args args + want bool + }{ + {"true", args{&wrappedProvisioner{ + Interface: &provisioner.JWK{Name: "jwt"}, + raInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, + }}, true}, + {"nil ra", args{&wrappedProvisioner{ + Interface: &provisioner.JWK{Name: "jwt"}, + }}, false}, + {"not ra", args{&provisioner.JWK{Name: "jwt"}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isRAProvisioner(tt.args.p); got != tt.want { + t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/tls.go b/authority/tls.go index b5d85074..11c61b9e 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -34,6 +34,19 @@ import ( "github.com/smallstep/nosql/database" ) +type tokenKey struct{} + +// NewTokenContext adds the given token to the context. +func NewTokenContext(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, tokenKey{}, token) +} + +// TokenFromContext returns the token from the given context. +func TokenFromContext(ctx context.Context) (token string, ok bool) { + token, ok = ctx.Value(tokenKey{}).(string) + return +} + // GetTLSOptions returns the tls options configured. func (a *Authority) GetTLSOptions() *config.TLSOptions { return a.config.TLS @@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error { return a.policyEngine.AreSANsAllowed(sans) } -// Renew creates a new Certificate identical to the old certificate, except -// with a validity window that begins 'now'. +// Renew creates a new Certificate identical to the old certificate, except with +// a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { - return a.Rekey(oldCert, nil) + return a.RenewContext(context.Background(), oldCert, nil) } -// Rekey is used for rekeying and renewing based on the public key. -// If the public key is 'nil' then it's assumed that the cert should be renewed -// using the existing public key. If the public key is not 'nil' then it's -// assumed that the cert should be rekeyed. +// Rekey is used for rekeying and renewing based on the public key. If the +// public key is 'nil' then it's assumed that the cert should be renewed using +// the existing public key. If the public key is not 'nil' then it's assumed +// that the cert should be rekeyed. +// // For both Rekey and Renew all other attributes of the new certificate should // match the old certificate. The exceptions are 'AuthorityKeyId' (which may // have changed), 'SubjectKeyId' (different in case of rekey), and // 'NotBefore/NotAfter' (the validity duration of the new certificate should be // equal to the old one, but starting 'now'). func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + return a.RenewContext(context.Background(), oldCert, pk) +} + +// RenewContext creates a new certificate identical to the old one, but it can +// optionally replace the public key with the given one. When running on RA +// mode, it can only renew a certificate using a renew token instead. +// +// For both rekey and renew operations, all other attributes of the new +// certificate should match the old certificate. The exceptions are +// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case +// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new +// certificate should be equal to the old one, but starting 'now'). +func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { isRekey := (pk != nil) - opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())} + opts := []errs.Option{ + errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()), + } // Check step provisioner extensions - if err := a.authorizeRenew(oldCert); err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) + if err := a.authorizeRenew(ctx, oldCert); err != nil { + return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } // Durations @@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil { var ee *errs.Error if errors.As(err, &ee) { - return nil, errs.ApplyOptions(ee, opts...) + return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...) } return nil, errs.InternalServerErr(err, errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()), @@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 ) } + // The token can optionally be in the context. If the CA is running in RA + // mode, this can be used to renew a certificate. + token, _ := TokenFromContext(ctx) + resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{ Template: newCert, Lifetime: lifetime, Backdate: backdate, + Token: token, }) if err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) + return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil { if !errors.Is(err, db.ErrNotImplemented) { - return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) + return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } } diff --git a/authority/tls_test.go b/authority/tls_test.go index 918adbdc..5d63b3dd 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) { return &renewTest{ auth: _a, cert: cert, - err: errors.New("authority.Rekey: error creating certificate"), + err: errors.New("error creating certificate"), code: http.StatusInternalServerError, }, nil }, "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) { return &renewTest{ auth: aa, cert: cert, - err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"), + err: errors.New("authority.authorizeRenew: not authorized"), code: http.StatusUnauthorized, }, nil }, @@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) { return &renewTest{ auth: _a, cert: cert, - err: errors.New("authority.Rekey: error creating certificate"), + err: errors.New("error creating certificate"), code: http.StatusInternalServerError, }, nil }, "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, diff --git a/cas/apiv1/requests.go b/cas/apiv1/requests.go index eff53a77..fdbb285e 100644 --- a/cas/apiv1/requests.go +++ b/cas/apiv1/requests.go @@ -81,6 +81,7 @@ type RenewCertificateRequest struct { CSR *x509.CertificateRequest Lifetime time.Duration Backdate time.Duration + Token string RequestID string } diff --git a/cas/apiv1/services.go b/cas/apiv1/services.go index f1d02b3c..f10a3e17 100644 --- a/cas/apiv1/services.go +++ b/cas/apiv1/services.go @@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string { func (e NotImplementedError) StatusCode() int { return http.StatusNotImplemented } + +// ValidationError is the type of error returned if request is not properly +// validated. +type ValidationError struct { + Message string +} + +// NotImplementedError implements the error interface. +func (e ValidationError) Error() string { + if e.Message != "" { + return e.Message + } + return "bad request" +} + +// StatusCode implements the StatusCoder interface and returns the HTTP 400 +// error. +func (e ValidationError) StatusCode() int { + return http.StatusBadRequest +} diff --git a/cas/apiv1/services_test.go b/cas/apiv1/services_test.go index f8e16138..9289de76 100644 --- a/cas/apiv1/services_test.go +++ b/cas/apiv1/services_test.go @@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) { }) } } + +func TestValidationError_Error(t *testing.T) { + type fields struct { + Message string + } + tests := []struct { + name string + fields fields + want string + }{ + {"default", fields{""}, "bad request"}, + {"with message", fields{"token is empty"}, "token is empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ValidationError{ + Message: tt.fields.Message, + } + if got := e.Error(); got != tt.want { + t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidationError_StatusCode(t *testing.T) { + type fields struct { + Message string + } + tests := []struct { + name string + fields fields + want int + }{ + {"default", fields{""}, 400}, + {"with message", fields{"token is empty"}, 400}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ValidationError{ + Message: tt.fields.Message, + } + if got := e.StatusCode(); got != tt.want { + t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go index 6c2acc84..c64963e6 100644 --- a/cas/stepcas/stepcas.go +++ b/cas/stepcas/stepcas.go @@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 // RenewCertificate will always return a non-implemented error as mTLS renewals // are not supported yet. func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { - return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"} + if req.Token == "" { + return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"} + } + + resp, err := s.client.RenewWithToken(req.Token) + if err != nil { + return nil, err + } + + var chain []*x509.Certificate + cert := resp.CertChainPEM[0].Certificate + for _, c := range resp.CertChainPEM[1:] { + chain = append(chain, c.Certificate) + } + + return &apiv1.RenewCertificateResponse{ + Certificate: cert, + CertificateChain: chain, + }, nil } // RevokeCertificate revokes a certificate. diff --git a/cas/stepcas/stepcas_test.go b/cas/stepcas/stepcas_test.go index cc8ea72e..6691a4b4 100644 --- a/cas/stepcas/stepcas_test.go +++ b/cas/stepcas/stepcas_test.go @@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) { writeJSON(w, api.SignResponse{ CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)}, }) + case r.RequestURI == "/renew": + if r.Header.Get("Authorization") == "Bearer fail" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `{"error":"fail","message":"fail"}`) + return + } + w.WriteHeader(http.StatusOK) + writeJSON(w, api.SignResponse{ + CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)}, + }) case r.RequestURI == "/revoke": var msg api.RevokeRequest parseJSON(r, &msg) @@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) { func TestStepCAS_RenewCertificate(t *testing.T) { caURL, client := testCAHelper(t) - x5c := testX5CIssuer(t, caURL, "") jwk := testJWKIssuer(t, caURL, "") + tokenIssuer := testX5CIssuer(t, caURL, "") + token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil) + if err != nil { + t.Fatal(err) + } + type fields struct { iss stepIssuer client *ca.Client @@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) { want *apiv1.RenewCertificateResponse wantErr bool }{ - {"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ - CSR: testCR, + {"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ + Template: &x509.Certificate{}, + Backdate: time.Minute, + Lifetime: time.Hour, + Token: token, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testCrt, + CertificateChain: []*x509.Certificate{testIssCrt}, + }, false}, + {"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ + Template: &x509.Certificate{}, + Backdate: time.Minute, Lifetime: time.Hour, }}, nil, true}, - {"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ - CSR: testCR, + {"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ + Template: &x509.Certificate{}, + Backdate: time.Minute, Lifetime: time.Hour, + Token: "fail", }}, nil, true}, } for _, tt := range tests { @@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) { return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want) + t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate)) + t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain)) + + t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject) } }) } diff --git a/db/db.go b/db/db.go index 784c75f4..b3137a50 100644 --- a/db/db.go +++ b/db/db.go @@ -28,8 +28,9 @@ var ( sshHostPrincipalsTable = []byte("ssh_host_principals") ) -var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table. -// is this acceptable? probably not.... +// TODO: at the moment we store a single CRL in the database, in a dedicated table. +// is this acceptable? probably not.... +var crlKey = []byte("crl") // ErrAlreadyExists can be returned if the DB attempts to set a key that has // been previously set. @@ -323,7 +324,8 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { // CertificateData is the JSON representation of the data stored in // x509_certs_data table. type CertificateData struct { - Provisioner *ProvisionerData `json:"provisioner,omitempty"` + Provisioner *ProvisionerData `json:"provisioner,omitempty"` + RaInfo *provisioner.RAInfo `json:"ra,omitempty"` } // ProvisionerData is the JSON representation of the provisioner stored in the @@ -334,6 +336,10 @@ type ProvisionerData struct { Type string `json:"type"` } +type raProvisioner interface { + RAInfo() *provisioner.RAInfo +} + // StoreCertificateChain stores the leaf certificate and the provisioner that // authorized the certificate. func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { @@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert Name: p.GetName(), Type: p.GetType().String(), } + if rap, ok := p.(raProvisioner); ok { + data.RaInfo = rap.RAInfo() + } } b, err := json.Marshal(data) if err != nil { @@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert return nil } +// StoreRenewedCertificate stores the leaf certificate and the provisioner that +// authorized the old certificate if available. +func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error { + var certificateData []byte + if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil { + if b, err := json.Marshal(data); err == nil { + certificateData = b + } + } + + leaf := chain[0] + serialNumber := []byte(leaf.SerialNumber.String()) + + // Add certificate and certificate data in one transaction. + tx := new(database.Tx) + tx.Set(certsTable, serialNumber, leaf.Raw) + if certificateData != nil { + tx.Set(certsDataTable, serialNumber, certificateData) + } + if err := db.Update(tx); err != nil { + return errors.Wrap(err, "database Update error") + } + return nil +} + // UseToken returns true if we were able to successfully store the token for // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { diff --git a/db/db_test.go b/db/db_test.go index b4515a5b..7668ae58 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "bytes" "crypto/x509" "errors" "math/big" @@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) { } } +// wrappedProvisioner implements raProvisioner and attProvisioner. +type wrappedProvisioner struct { + provisioner.Interface + raInfo *provisioner.RAInfo +} + +func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo { + return p.raInfo +} + func TestDB_StoreCertificateChain(t *testing.T) { p := &provisioner.JWK{ ID: "some-id", Name: "admin", Type: "JWK", } + rap := &wrappedProvisioner{ + Interface: p, + raInfo: &provisioner.RAInfo{ + ProvisionerID: "ra-id", + ProvisionerType: "JWK", + ProvisionerName: "ra", + }, + } chain := []*x509.Certificate{ {Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)}, } @@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) { return nil }, }, true}, args{p, chain}, false}, + {"ok ra provisioner", fields{&MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Fatal("unexpected number of operations") + } + assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[0].Key) + assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) + assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[1].Key) + assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value) + assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value)) + return nil + }, + }, true}, args{rap, chain}, false}, {"ok no provisioner", fields{&MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 2 { @@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) { }) } } + +func TestDB_StoreRenewedCertificate(t *testing.T) { + oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)} + chain := []*x509.Certificate{ + &x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")}, + &x509.Certificate{SerialNumber: big.NewInt(0)}, + } + + testErr := errors.New("test error") + certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`) + matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool { + return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value) + } + + type fields struct { + DB nosql.DB + isUp bool + } + type args struct { + oldCert *x509.Certificate + chain []*x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) { + return certsData, nil + } + t.Error("ok failed: unexpected get") + return nil, testErr + }, + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Error("ok failed: unexpected number of operations") + return testErr + } + op0, op1 := tx.Operations[0], tx.Operations[1] + if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { + t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) + return testErr + } + if !matchOperation(op1, certsDataTable, []byte("2"), certsData) { + t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value) + return testErr + } + return nil + }, + }, true}, args{oldCert, chain}, false}, + {"ok no data", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 1 { + t.Error("ok failed: unexpected number of operations") + return testErr + } + op0 := tx.Operations[0] + if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { + t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) + return testErr + } + return nil + }, + }, true}, args{oldCert, chain}, false}, + {"ok fail marshal", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return []byte(`{"bad":"json"`), nil + }, + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 1 { + t.Error("ok failed: unexpected number of operations") + return testErr + } + op0 := tx.Operations[0] + if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { + t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) + return testErr + } + return nil + }, + }, true}, args{oldCert, chain}, false}, + {"fail", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return certsData, nil + }, + MUpdate: func(tx *database.Tx) error { + return testErr + }, + }, true}, args{oldCert, chain}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &DB{ + DB: tt.fields.DB, + isUp: tt.fields.isUp, + } + if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr { + t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}