Add support for renew when using stepcas

It supports renewing X.509 certificates when an RA is configured with stepcas.
This will only work when the renewal uses a token, and it won't work with mTLS.

The audience cannot be properly verified when an RA is used, to avoid this we
will get from the database if an RA was used to issue the initial certificate
and we will accept the renew token.

Fixes #1021 for stepcas
pull/1156/head
Mariano Cano 1 year ago
parent 068a2dae8e
commit c7f226bcec
No known key found for this signature in database

@ -40,6 +40,7 @@ type Authority interface {
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
LoadProvisionerByName(string) (provisioner.Interface, error)

@ -192,6 +192,7 @@ type mockAuthority struct {
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
loadProvisionerByName func(name string) (provisioner.Interface, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.renewContext != nil {
return m.renewContext(ctx, oldcert, pk)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.rekey != nil {
return m.rekey(oldcert, pk)

@ -6,6 +6,7 @@ import (
"strings"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
)
@ -17,14 +18,22 @@ const (
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func Renew(w http.ResponseWriter, r *http.Request) {
cert, err := getPeerCertificate(r)
ctx := r.Context()
// Get the leaf certificate from the peer or the token.
cert, token, err := getPeerCertificate(r)
if err != nil {
render.Error(w, err)
return
}
a := mustAuthority(r.Context())
certChain, err := a.Renew(cert)
// The token can be used by RAs to renew a certificate.
if token != "" {
ctx = authority.NewTokenContext(ctx, token)
}
a := mustAuthority(ctx)
certChain, err := a.RenewContext(ctx, cert, nil)
if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) {
}, http.StatusCreated)
}
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
return r.TLS.PeerCertificates[0], "", nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
ctx := r.Context()
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
return peer, parts[1], err
}
}
return nil, errs.BadRequest("missing client certificate")
return nil, "", errs.BadRequest("missing client certificate")
}

@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
// extra extension cannot be found, authorize the renewal by default.
//
// TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error {
serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
}
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
if err := p.AuthorizeRenew(ctx, cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
}
return nil
@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
}
audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) {
if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) {
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}

@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
err := tc.auth.authorizeRenew(tc.cert)
err := tc.auth.authorizeRenew(context.Background(), tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError
@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
})
return nil
}))
a4 := testAuthority(t)
a4.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"},
RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}, nil
},
}
t4, c4 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://ra.example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
{"ok ra provisioner", a4, args{ctx, t4}, c4, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},

@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa
}
}
// wrapRAProvisioner wraps the given provisioner with RA information.
func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner {
return &wrappedProvisioner{
Interface: p,
raInfo: raInfo,
}
}
// isRAProvisioner returns if the given provisioner is an RA provisioner.
func isRAProvisioner(p provisioner.Interface) bool {
if rap, ok := p.(raProvisioner); ok {
return rap.RAInfo() != nil
}
return false
}
// wrappedProvisioner implements raProvisioner and attProvisioner.
type wrappedProvisioner struct {
provisioner.Interface
@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr
}
if err == nil && data != nil && data.Provisioner != nil {
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
if data.RaInfo != nil {
return wrapRAProvisioner(p, data.RaInfo), nil
}
return p, nil
}
}

@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) {
})
}
}
func Test_wrapRAProvisioner(t *testing.T) {
type args struct {
p provisioner.Interface
raInfo *provisioner.RAInfo
}
tests := []struct {
name string
args args
want *wrappedProvisioner
}{
{"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) {
t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isRAProvisioner(t *testing.T) {
type args struct {
p provisioner.Interface
}
tests := []struct {
name string
args args
want bool
}{
{"true", args{&wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}}, true},
{"nil ra", args{&wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
}}, false},
{"not ra", args{&provisioner.JWK{Name: "jwt"}}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isRAProvisioner(tt.args.p); got != tt.want {
t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want)
}
})
}
}

@ -34,6 +34,19 @@ import (
"github.com/smallstep/nosql/database"
)
type tokenKey struct{}
// NewTokenContext adds the given token to the context.
func NewTokenContext(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, tokenKey{}, token)
}
// TokenFromContext returns the token from the given context.
func TokenFromContext(ctx context.Context) (token string, ok bool) {
token, ok = ctx.Value(tokenKey{}).(string)
return
}
// GetTLSOptions returns the tls options configured.
func (a *Authority) GetTLSOptions() *config.TLSOptions {
return a.config.TLS
@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error {
return a.policyEngine.AreSANsAllowed(sans)
}
// Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'.
// Renew creates a new Certificate identical to the old certificate, except with
// a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
return a.Rekey(oldCert, nil)
return a.RenewContext(context.Background(), oldCert, nil)
}
// Rekey is used for rekeying and renewing based on the public key.
// If the public key is 'nil' then it's assumed that the cert should be renewed
// using the existing public key. If the public key is not 'nil' then it's
// assumed that the cert should be rekeyed.
// Rekey is used for rekeying and renewing based on the public key. If the
// public key is 'nil' then it's assumed that the cert should be renewed using
// the existing public key. If the public key is not 'nil' then it's assumed
// that the cert should be rekeyed.
//
// For both Rekey and Renew all other attributes of the new certificate should
// match the old certificate. The exceptions are 'AuthorityKeyId' (which may
// have changed), 'SubjectKeyId' (different in case of rekey), and
// 'NotBefore/NotAfter' (the validity duration of the new certificate should be
// equal to the old one, but starting 'now').
func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
return a.RenewContext(context.Background(), oldCert, pk)
}
// RenewContext creates a new certificate identical to the old one, but it can
// optionally replace the public key with the given one. When running on RA
// mode, it can only renew a certificate using a renew token instead.
//
// For both rekey and renew operations, all other attributes of the new
// certificate should match the old certificate. The exceptions are
// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case
// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new
// certificate should be equal to the old one, but starting 'now').
func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
isRekey := (pk != nil)
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
opts := []errs.Option{
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
}
// Check step provisioner extensions
if err := a.authorizeRenew(oldCert); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
if err := a.authorizeRenew(ctx, oldCert); err != nil {
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
}
// Durations
@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil {
var ee *errs.Error
if errors.As(err, &ee) {
return nil, errs.ApplyOptions(ee, opts...)
return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...)
}
return nil, errs.InternalServerErr(err,
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
)
}
// The token can optionally be in the context. If the CA is running in RA
// mode, this can be used to renew a certificate.
token, _ := TokenFromContext(ctx)
resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{
Template: newCert,
Lifetime: lifetime,
Backdate: backdate,
Token: token,
})
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
}
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
if !errors.Is(err, db.ErrNotImplemented) {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
}
}

@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) {
return &renewTest{
auth: _a,
cert: cert,
err: errors.New("authority.Rekey: error creating certificate"),
err: errors.New("error creating certificate"),
code: http.StatusInternalServerError,
}, nil
},
"fail/unauthorized": func() (*renewTest, error) {
return &renewTest{
cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) {
return &renewTest{
auth: aa,
cert: cert,
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
err: errors.New("authority.authorizeRenew: not authorized"),
code: http.StatusUnauthorized,
}, nil
},
@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) {
return &renewTest{
auth: _a,
cert: cert,
err: errors.New("authority.Rekey: error creating certificate"),
err: errors.New("error creating certificate"),
code: http.StatusInternalServerError,
}, nil
},
"fail/unauthorized": func() (*renewTest, error) {
return &renewTest{
cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},

@ -81,6 +81,7 @@ type RenewCertificateRequest struct {
CSR *x509.CertificateRequest
Lifetime time.Duration
Backdate time.Duration
Token string
RequestID string
}

@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string {
func (e NotImplementedError) StatusCode() int {
return http.StatusNotImplemented
}
// ValidationError is the type of error returned if request is not properly
// validated.
type ValidationError struct {
Message string
}
// NotImplementedError implements the error interface.
func (e ValidationError) Error() string {
if e.Message != "" {
return e.Message
}
return "bad request"
}
// StatusCode implements the StatusCoder interface and returns the HTTP 400
// error.
func (e ValidationError) StatusCode() int {
return http.StatusBadRequest
}

@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) {
})
}
}
func TestValidationError_Error(t *testing.T) {
type fields struct {
Message string
}
tests := []struct {
name string
fields fields
want string
}{
{"default", fields{""}, "bad request"},
{"with message", fields{"token is empty"}, "token is empty"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ValidationError{
Message: tt.fields.Message,
}
if got := e.Error(); got != tt.want {
t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want)
}
})
}
}
func TestValidationError_StatusCode(t *testing.T) {
type fields struct {
Message string
}
tests := []struct {
name string
fields fields
want int
}{
{"default", fields{""}, 400},
{"with message", fields{"token is empty"}, 400},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ValidationError{
Message: tt.fields.Message,
}
if got := e.StatusCode(); got != tt.want {
t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want)
}
})
}
}

@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
// RenewCertificate will always return a non-implemented error as mTLS renewals
// are not supported yet.
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"}
if req.Token == "" {
return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"}
}
resp, err := s.client.RenewWithToken(req.Token)
if err != nil {
return nil, err
}
var chain []*x509.Certificate
cert := resp.CertChainPEM[0].Certificate
for _, c := range resp.CertChainPEM[1:] {
chain = append(chain, c.Certificate)
}
return &apiv1.RenewCertificateResponse{
Certificate: cert,
CertificateChain: chain,
}, nil
}
// RevokeCertificate revokes a certificate.

@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) {
writeJSON(w, api.SignResponse{
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
})
case r.RequestURI == "/renew":
if r.Header.Get("Authorization") == "Bearer fail" {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `{"error":"fail","message":"fail"}`)
return
}
w.WriteHeader(http.StatusOK)
writeJSON(w, api.SignResponse{
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
})
case r.RequestURI == "/revoke":
var msg api.RevokeRequest
parseJSON(r, &msg)
@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) {
func TestStepCAS_RenewCertificate(t *testing.T) {
caURL, client := testCAHelper(t)
x5c := testX5CIssuer(t, caURL, "")
jwk := testJWKIssuer(t, caURL, "")
tokenIssuer := testX5CIssuer(t, caURL, "")
token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil)
if err != nil {
t.Fatal(err)
}
type fields struct {
iss stepIssuer
client *ca.Client
@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
want *apiv1.RenewCertificateResponse
wantErr bool
}{
{"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
CSR: testCR,
{"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
Template: &x509.Certificate{},
Backdate: time.Minute,
Lifetime: time.Hour,
Token: token,
}}, &apiv1.RenewCertificateResponse{
Certificate: testCrt,
CertificateChain: []*x509.Certificate{testIssCrt},
}, false},
{"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
Template: &x509.Certificate{},
Backdate: time.Minute,
Lifetime: time.Hour,
}}, nil, true},
{"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
CSR: testCR,
{"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
Template: &x509.Certificate{},
Backdate: time.Minute,
Lifetime: time.Hour,
Token: "fail",
}}, nil, true},
}
for _, tt := range tests {
@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want)
t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate))
t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain))
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject)
}
})
}

@ -28,8 +28,9 @@ var (
sshHostPrincipalsTable = []byte("ssh_host_principals")
)
var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table.
// is this acceptable? probably not....
// TODO: at the moment we store a single CRL in the database, in a dedicated table.
// is this acceptable? probably not....
var crlKey = []byte("crl")
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
// been previously set.
@ -323,7 +324,8 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
// CertificateData is the JSON representation of the data stored in
// x509_certs_data table.
type CertificateData struct {
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
RaInfo *provisioner.RAInfo `json:"ra,omitempty"`
}
// ProvisionerData is the JSON representation of the provisioner stored in the
@ -334,6 +336,10 @@ type ProvisionerData struct {
Type string `json:"type"`
}
type raProvisioner interface {
RAInfo() *provisioner.RAInfo
}
// StoreCertificateChain stores the leaf certificate and the provisioner that
// authorized the certificate.
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
Name: p.GetName(),
Type: p.GetType().String(),
}
if rap, ok := p.(raProvisioner); ok {
data.RaInfo = rap.RAInfo()
}
}
b, err := json.Marshal(data)
if err != nil {
@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
return nil
}
// StoreRenewedCertificate stores the leaf certificate and the provisioner that
// authorized the old certificate if available.
func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error {
var certificateData []byte
if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil {
if b, err := json.Marshal(data); err == nil {
certificateData = b
}
}
leaf := chain[0]
serialNumber := []byte(leaf.SerialNumber.String())
// Add certificate and certificate data in one transaction.
tx := new(database.Tx)
tx.Set(certsTable, serialNumber, leaf.Raw)
if certificateData != nil {
tx.Set(certsDataTable, serialNumber, certificateData)
}
if err := db.Update(tx); err != nil {
return errors.Wrap(err, "database Update error")
}
return nil
}
// UseToken returns true if we were able to successfully store the token for
// for the first time, false otherwise.
func (db *DB) UseToken(id, tok string) (bool, error) {

@ -1,6 +1,7 @@
package db
import (
"bytes"
"crypto/x509"
"errors"
"math/big"
@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) {
}
}
// wrappedProvisioner implements raProvisioner and attProvisioner.
type wrappedProvisioner struct {
provisioner.Interface
raInfo *provisioner.RAInfo
}
func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo {
return p.raInfo
}
func TestDB_StoreCertificateChain(t *testing.T) {
p := &provisioner.JWK{
ID: "some-id",
Name: "admin",
Type: "JWK",
}
rap := &wrappedProvisioner{
Interface: p,
raInfo: &provisioner.RAInfo{
ProvisionerID: "ra-id",
ProvisionerType: "JWK",
ProvisionerName: "ra",
},
}
chain := []*x509.Certificate{
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
}
@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) {
return nil
},
}, true}, args{p, chain}, false},
{"ok ra provisioner", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
t.Fatal("unexpected number of operations")
}
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value)
assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value))
return nil
},
}, true}, args{rap, chain}, false},
{"ok no provisioner", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) {
})
}
}
func TestDB_StoreRenewedCertificate(t *testing.T) {
oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)}
chain := []*x509.Certificate{
&x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")},
&x509.Certificate{SerialNumber: big.NewInt(0)},
}
testErr := errors.New("test error")
certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`)
matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool {
return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value)
}
type fields struct {
DB nosql.DB
isUp bool
}
type args struct {
oldCert *x509.Certificate
chain []*x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) {
return certsData, nil
}
t.Error("ok failed: unexpected get")
return nil, testErr
},
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
t.Error("ok failed: unexpected number of operations")
return testErr
}
op0, op1 := tx.Operations[0], tx.Operations[1]
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
return testErr
}
if !matchOperation(op1, certsDataTable, []byte("2"), certsData) {
t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value)
return testErr
}
return nil
},
}, true}, args{oldCert, chain}, false},
{"ok no data", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 1 {
t.Error("ok failed: unexpected number of operations")
return testErr
}
op0 := tx.Operations[0]
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
return testErr
}
return nil
},
}, true}, args{oldCert, chain}, false},
{"ok fail marshal", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return []byte(`{"bad":"json"`), nil
},
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 1 {
t.Error("ok failed: unexpected number of operations")
return testErr
}
op0 := tx.Operations[0]
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
return testErr
}
return nil
},
}, true}, args{oldCert, chain}, false},
{"fail", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return certsData, nil
},
MUpdate: func(tx *database.Tx) error {
return testErr
},
}, true}, args{oldCert, chain}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &DB{
DB: tt.fields.DB,
isUp: tt.fields.isUp,
}
if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr {
t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

Loading…
Cancel
Save