|
|
|
@ -14,12 +14,13 @@ import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
certsTable = []byte("x509_certs")
|
|
|
|
|
revokedCertsTable = []byte("revoked_x509_certs")
|
|
|
|
|
usedOTTTable = []byte("used_ott")
|
|
|
|
|
sshCertsTable = []byte("ssh_certs")
|
|
|
|
|
sshHostsTable = []byte("ssh_hosts")
|
|
|
|
|
sshUsersTable = []byte("ssh_users")
|
|
|
|
|
certsTable = []byte("x509_certs")
|
|
|
|
|
revokedCertsTable = []byte("revoked_x509_certs")
|
|
|
|
|
usedOTTTable = []byte("used_ott")
|
|
|
|
|
sshCertsTable = []byte("ssh_certs")
|
|
|
|
|
sshHostsTable = []byte("ssh_hosts")
|
|
|
|
|
sshUsersTable = []byte("ssh_users")
|
|
|
|
|
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
|
|
|
@ -42,6 +43,7 @@ type AuthDB interface {
|
|
|
|
|
UseToken(id, tok string) (bool, error)
|
|
|
|
|
IsSSHHost(name string) (bool, error)
|
|
|
|
|
StoreSSHCertificate(crt *ssh.Certificate) error
|
|
|
|
|
GetSSHHostPrincipals() ([]string, error)
|
|
|
|
|
Shutdown() error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -160,19 +162,32 @@ func (db *DB) IsSSHHost(principal string) (bool, error) {
|
|
|
|
|
return true, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type sshHostPrincipalData struct {
|
|
|
|
|
Serial string
|
|
|
|
|
Expiry uint64
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// StoreSSHCertificate stores an SSH certificate.
|
|
|
|
|
func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
|
|
|
|
var table []byte
|
|
|
|
|
serial := strconv.FormatUint(crt.Serial, 10)
|
|
|
|
|
tx := new(database.Tx)
|
|
|
|
|
tx.Set(sshCertsTable, []byte(serial), crt.Marshal())
|
|
|
|
|
if crt.CertType == ssh.HostCert {
|
|
|
|
|
table = sshHostsTable
|
|
|
|
|
for _, p := range crt.ValidPrincipals {
|
|
|
|
|
hostPrincipalData, err := json.Marshal(sshHostPrincipalData{
|
|
|
|
|
Serial: serial,
|
|
|
|
|
Expiry: crt.ValidBefore,
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
tx.Set(sshHostsTable, []byte(strings.ToLower(p)), []byte(serial))
|
|
|
|
|
tx.Set(sshHostPrincipalsTable, []byte(strings.ToLower(p)), hostPrincipalData)
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
table = sshUsersTable
|
|
|
|
|
}
|
|
|
|
|
for _, p := range crt.ValidPrincipals {
|
|
|
|
|
tx.Set(table, []byte(strings.ToLower(p)), []byte(serial))
|
|
|
|
|
for _, p := range crt.ValidPrincipals {
|
|
|
|
|
tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if err := db.Update(tx); err != nil {
|
|
|
|
|
return errors.Wrap(err, "database Update error")
|
|
|
|
@ -181,6 +196,25 @@ func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetSSHHostPrincipals gets a list of all valid host principals.
|
|
|
|
|
func (db *DB) GetSSHHostPrincipals() ([]string, error) {
|
|
|
|
|
entries, err := db.List(sshHostPrincipalsTable)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
var principals []string
|
|
|
|
|
for _, e := range entries {
|
|
|
|
|
var data sshHostPrincipalData
|
|
|
|
|
if err := json.Unmarshal(e.Value, &data); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
if time.Unix(int64(data.Expiry), 0).After(time.Now()) {
|
|
|
|
|
principals = append(principals, string(e.Key))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return principals, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Shutdown sends a shutdown message to the database.
|
|
|
|
|
func (db *DB) Shutdown() error {
|
|
|
|
|
if db.isUp {
|
|
|
|
|