diff --git a/db/db.go b/db/db.go index 6d48723f..3427d2bb 100644 --- a/db/db.go +++ b/db/db.go @@ -8,20 +8,22 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" "golang.org/x/crypto/ssh" ) var ( - certsTable = []byte("x509_certs") - revokedCertsTable = []byte("revoked_x509_certs") - revokedSSHCertsTable = []byte("revoked_ssh_certs") - usedOTTTable = []byte("used_ott") - sshCertsTable = []byte("ssh_certs") - sshHostsTable = []byte("ssh_hosts") - sshUsersTable = []byte("ssh_users") - sshHostPrincipalsTable = []byte("ssh_host_principals") + certsTable = []byte("x509_certs") + certsToProvisionerTable = []byte("x509_certs_provisioner") + revokedCertsTable = []byte("revoked_x509_certs") + revokedSSHCertsTable = []byte("revoked_ssh_certs") + usedOTTTable = []byte("used_ott") + sshCertsTable = []byte("ssh_certs") + sshHostsTable = []byte("ssh_hosts") + sshUsersTable = []byte("ssh_users") + sshHostPrincipalsTable = []byte("ssh_host_principals") ) // ErrAlreadyExists can be returned if the DB attempts to set a key that has @@ -82,7 +84,7 @@ func New(c *Config) (AuthDB, error) { tables := [][]byte{ revokedCertsTable, certsTable, usedOTTTable, sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable, - revokedSSHCertsTable, + revokedSSHCertsTable, certsToProvisionerTable, } for _, b := range tables { if err := db.CreateTable(b); err != nil { @@ -210,6 +212,36 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { return nil } +type certsToProvionersData struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` +} + +// StoreCertificateChain stores the leaf certificate and the provisioner that +// authorized the certificate. +func (d *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { + leaf := chain[0] + if err := d.StoreCertificate(leaf); err != nil { + return err + } + if p != nil { + b, err := json.Marshal(certsToProvionersData{ + ID: p.GetID(), + Name: p.GetName(), + Type: p.GetType().String(), + }) + if err != nil { + return errors.Wrap(err, "error marshaling json") + } + + if err := d.Set(certsToProvisionerTable, []byte(leaf.SerialNumber.String()), b); err != nil { + return errors.Wrap(err, "database Set error") + } + } + return nil +} + // UseToken returns true if we were able to successfully store the token for // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { diff --git a/db/db_test.go b/db/db_test.go index 40f59215..5a7e2d38 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,10 +1,14 @@ package db import ( + "crypto/x509" "errors" + "math/big" "testing" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" ) @@ -158,3 +162,87 @@ func TestUseToken(t *testing.T) { }) } } + +func TestDB_StoreCertificateChain(t *testing.T) { + p := &provisioner.JWK{ + ID: "some-id", + Name: "admin", + Type: "JWK", + } + chain := []*x509.Certificate{ + {Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)}, + } + type fields struct { + DB nosql.DB + isUp bool + } + type args struct { + p provisioner.Interface + chain []*x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&MockNoSQLDB{ + MSet: func(bucket, key, value []byte) error { + switch string(bucket) { + case "x509_certs": + assert.Equals(t, key, []byte("1234")) + assert.Equals(t, value, []byte("the certificate")) + case "x509_certs_provisioner": + assert.Equals(t, key, []byte("1234")) + assert.Equals(t, value, []byte(`{"id":"some-id","name":"admin","type":"JWK"}`)) + default: + t.Errorf("unexpected bucket %s", bucket) + } + return nil + }, + }, true}, args{p, chain}, false}, + {"ok no provisioner", fields{&MockNoSQLDB{ + MSet: func(bucket, key, value []byte) error { + switch string(bucket) { + case "x509_certs": + assert.Equals(t, key, []byte("1234")) + assert.Equals(t, value, []byte("the certificate")) + default: + t.Errorf("unexpected bucket %s", bucket) + } + return nil + }, + }, true}, args{nil, chain}, false}, + {"fail store certificate", fields{&MockNoSQLDB{ + MSet: func(bucket, key, value []byte) error { + switch string(bucket) { + case "x509_certs": + return errors.New("test error") + default: + return nil + } + }, + }, true}, args{p, chain}, true}, + {"fail store provisioner", fields{&MockNoSQLDB{ + MSet: func(bucket, key, value []byte) error { + switch string(bucket) { + case "x509_certs_provisioner": + return errors.New("test error") + default: + return nil + } + }, + }, true}, args{p, chain}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &DB{ + DB: tt.fields.DB, + isUp: tt.fields.isUp, + } + if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr { + t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}