mirror of
https://github.com/smallstep/certificates.git
synced 2024-10-31 03:20:16 +00:00
Add SSH getHosts api
This commit is contained in:
parent
ded8087042
commit
5616386eed
@ -257,6 +257,7 @@ func (h *caHandler) Route(r Router) {
|
|||||||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
||||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
||||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
||||||
|
r.MethodFunc("POST", "/ssh/get-hosts", h.SSHGetHosts)
|
||||||
|
|
||||||
// For compatibility with old code:
|
// For compatibility with old code:
|
||||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||||
|
18
api/ssh.go
18
api/ssh.go
@ -21,6 +21,7 @@ type SSHAuthority interface {
|
|||||||
GetSSHFederation() (*authority.SSHKeys, error)
|
GetSSHFederation() (*authority.SSHKeys, error)
|
||||||
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
||||||
CheckSSHHost(principal string) (bool, error)
|
CheckSSHHost(principal string) (bool, error)
|
||||||
|
GetSSHHosts() ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHSignRequest is the request body of an SSH certificate request.
|
// SSHSignRequest is the request body of an SSH certificate request.
|
||||||
@ -66,6 +67,11 @@ type SSHCertificate struct {
|
|||||||
*ssh.Certificate `json:"omitempty"`
|
*ssh.Certificate `json:"omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHGetHostsResponse
|
||||||
|
type SSHGetHostsResponse struct {
|
||||||
|
Hosts []string `json:"hosts"`
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||||
// base64 encoded, openssh wire format version of the certificate.
|
// base64 encoded, openssh wire format version of the certificate.
|
||||||
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||||
@ -369,3 +375,15 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||||||
Exists: exists,
|
Exists: exists,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||||
|
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hosts, err := h.Authority.GetSSHHosts()
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, InternalServerError(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
JSON(w, &SSHGetHostsResponse{
|
||||||
|
Hosts: hosts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -369,6 +369,16 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
|
|||||||
return exists, nil
|
return exists, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSSHHosts returns a list of valid host principals.
|
||||||
|
func (a *Authority) GetSSHHosts() ([]string, error) {
|
||||||
|
ps, err := a.db.GetSSHHostPrincipals()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ps, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||||
if a.config.SSH.AddUserPrincipal == "" {
|
if a.config.SSH.AddUserPrincipal == "" {
|
||||||
return SSHAddUserPrincipal
|
return SSHAddUserPrincipal
|
||||||
|
18
ca/client.go
18
ca/client.go
@ -611,6 +611,24 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
|
|||||||
return &check, nil
|
return &check, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHGetHostPrincipals performs the POST /ssh/check-host request to the CA with the
|
||||||
|
// given principal.
|
||||||
|
func (c *Client) SSHGetHostPrincipals() (*api.SSHGetHostsResponse, error) {
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"})
|
||||||
|
resp, err := c.client.Get(u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var hosts api.SSHGetHostsResponse
|
||||||
|
if err := readJSON(resp.Body, &hosts); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
|
}
|
||||||
|
return &hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
// RootFingerprint is a helper method that returns the current root fingerprint.
|
// RootFingerprint is a helper method that returns the current root fingerprint.
|
||||||
// It does an health connection and gets the fingerprint from the TLS verified
|
// It does an health connection and gets the fingerprint from the TLS verified
|
||||||
// chains.
|
// chains.
|
||||||
|
46
db/db.go
46
db/db.go
@ -20,6 +20,7 @@ var (
|
|||||||
sshCertsTable = []byte("ssh_certs")
|
sshCertsTable = []byte("ssh_certs")
|
||||||
sshHostsTable = []byte("ssh_hosts")
|
sshHostsTable = []byte("ssh_hosts")
|
||||||
sshUsersTable = []byte("ssh_users")
|
sshUsersTable = []byte("ssh_users")
|
||||||
|
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
||||||
@ -42,6 +43,7 @@ type AuthDB interface {
|
|||||||
UseToken(id, tok string) (bool, error)
|
UseToken(id, tok string) (bool, error)
|
||||||
IsSSHHost(name string) (bool, error)
|
IsSSHHost(name string) (bool, error)
|
||||||
StoreSSHCertificate(crt *ssh.Certificate) error
|
StoreSSHCertificate(crt *ssh.Certificate) error
|
||||||
|
GetSSHHostPrincipals() ([]string, error)
|
||||||
Shutdown() error
|
Shutdown() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,19 +162,32 @@ func (db *DB) IsSSHHost(principal string) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sshHostPrincipalData struct {
|
||||||
|
Serial string
|
||||||
|
Expiry uint64
|
||||||
|
}
|
||||||
|
|
||||||
// StoreSSHCertificate stores an SSH certificate.
|
// StoreSSHCertificate stores an SSH certificate.
|
||||||
func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||||
var table []byte
|
|
||||||
serial := strconv.FormatUint(crt.Serial, 10)
|
serial := strconv.FormatUint(crt.Serial, 10)
|
||||||
tx := new(database.Tx)
|
tx := new(database.Tx)
|
||||||
tx.Set(sshCertsTable, []byte(serial), crt.Marshal())
|
tx.Set(sshCertsTable, []byte(serial), crt.Marshal())
|
||||||
if crt.CertType == ssh.HostCert {
|
if crt.CertType == ssh.HostCert {
|
||||||
table = sshHostsTable
|
|
||||||
} else {
|
|
||||||
table = sshUsersTable
|
|
||||||
}
|
|
||||||
for _, p := range crt.ValidPrincipals {
|
for _, p := range crt.ValidPrincipals {
|
||||||
tx.Set(table, []byte(strings.ToLower(p)), []byte(serial))
|
hostPrincipalData, err := json.Marshal(sshHostPrincipalData{
|
||||||
|
Serial: serial,
|
||||||
|
Expiry: crt.ValidBefore,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tx.Set(sshHostsTable, []byte(strings.ToLower(p)), []byte(serial))
|
||||||
|
tx.Set(sshHostPrincipalsTable, []byte(strings.ToLower(p)), hostPrincipalData)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, p := range crt.ValidPrincipals {
|
||||||
|
tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := db.Update(tx); err != nil {
|
if err := db.Update(tx); err != nil {
|
||||||
return errors.Wrap(err, "database Update error")
|
return errors.Wrap(err, "database Update error")
|
||||||
@ -181,6 +196,25 @@ func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSSHHostPrincipals gets a list of all valid host principals.
|
||||||
|
func (db *DB) GetSSHHostPrincipals() ([]string, error) {
|
||||||
|
entries, err := db.List(sshHostPrincipalsTable)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var principals []string
|
||||||
|
for _, e := range entries {
|
||||||
|
var data sshHostPrincipalData
|
||||||
|
if err := json.Unmarshal(e.Value, &data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if time.Unix(int64(data.Expiry), 0).After(time.Now()) {
|
||||||
|
principals = append(principals, string(e.Key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return principals, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Shutdown sends a shutdown message to the database.
|
// Shutdown sends a shutdown message to the database.
|
||||||
func (db *DB) Shutdown() error {
|
func (db *DB) Shutdown() error {
|
||||||
if db.isUp {
|
if db.isUp {
|
||||||
|
@ -69,6 +69,11 @@ func (s *SimpleDB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
|||||||
return ErrNotImplemented
|
return ErrNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSSHHostPrincipals returns a "NotImplemented" error.
|
||||||
|
func (s *SimpleDB) GetSSHHostPrincipals() ([]string, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
// Shutdown returns nil
|
// Shutdown returns nil
|
||||||
func (s *SimpleDB) Shutdown() error {
|
func (s *SimpleDB) Shutdown() error {
|
||||||
return nil
|
return nil
|
||||||
|
Loading…
Reference in New Issue
Block a user