commit
7726f5ec75
@ -0,0 +1,159 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||||
|
type SSHAuthority interface {
|
||||||
|
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
|
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHRequest is the request body of an SSH certificate request.
|
||||||
|
type SignSSHRequest struct {
|
||||||
|
PublicKey []byte `json:"publicKey"` //base64 encoded
|
||||||
|
OTT string `json:"ott"`
|
||||||
|
CertType string `json:"certType,omitempty"`
|
||||||
|
Principals []string `json:"principals,omitempty"`
|
||||||
|
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||||
|
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||||
|
AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHResponse is the response object that returns the SSH certificate.
|
||||||
|
type SignSSHResponse struct {
|
||||||
|
Certificate SSHCertificate `json:"crt"`
|
||||||
|
AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificate represents the response SSH certificate.
|
||||||
|
type SSHCertificate struct {
|
||||||
|
*ssh.Certificate `json:"omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||||
|
// base64 encoded, openssh wire format version of the certificate.
|
||||||
|
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||||
|
if c.Certificate == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal())
|
||||||
|
return []byte(`"` + s + `"`), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is
|
||||||
|
// expected to be a quoted, base64 encoded, openssh wire formatted block of bytes.
|
||||||
|
func (c *SSHCertificate) UnmarshalJSON(data []byte) error {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding certificate")
|
||||||
|
}
|
||||||
|
if s == "" {
|
||||||
|
c.Certificate = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
certData, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding ssh certificate")
|
||||||
|
}
|
||||||
|
pub, err := ssh.ParsePublicKey(certData)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error parsing ssh certificate")
|
||||||
|
}
|
||||||
|
cert, ok := pub.(*ssh.Certificate)
|
||||||
|
if !ok {
|
||||||
|
return errors.Errorf("error decoding ssh certificate: %T is not an *ssh.Certificate", pub)
|
||||||
|
}
|
||||||
|
c.Certificate = cert
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the SignSSHRequest.
|
||||||
|
func (s *SignSSHRequest) Validate() error {
|
||||||
|
switch {
|
||||||
|
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
|
||||||
|
return errors.Errorf("unknown certType %s", s.CertType)
|
||||||
|
case len(s.PublicKey) == 0:
|
||||||
|
return errors.New("missing or empty publicKey")
|
||||||
|
case len(s.OTT) == 0:
|
||||||
|
return errors.New("missing or empty ott")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||||
|
// (ott) from the body and creates a new SSH certificate with the information in
|
||||||
|
// the request.
|
||||||
|
func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var body SignSSHRequest
|
||||||
|
if err := ReadJSON(r.Body, &body); err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOtt(w, body.OTT)
|
||||||
|
if err := body.Validate(); err != nil {
|
||||||
|
WriteError(w, BadRequest(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addUserPublicKey ssh.PublicKey
|
||||||
|
if body.AddUserPublicKey != nil {
|
||||||
|
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := provisioner.SSHOptions{
|
||||||
|
CertType: body.CertType,
|
||||||
|
Principals: body.Principals,
|
||||||
|
ValidBefore: body.ValidBefore,
|
||||||
|
ValidAfter: body.ValidAfter,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
|
||||||
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Unauthorized(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Forbidden(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addUserCertificate *SSHCertificate
|
||||||
|
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||||
|
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Forbidden(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addUserCertificate = &SSHCertificate{addUserCert}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
JSON(w, &SignSSHResponse{
|
||||||
|
Certificate: SSHCertificate{cert},
|
||||||
|
AddUserCertificate: addUserCertificate,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,327 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sshSignerKey = mustKey()
|
||||||
|
sshUserKey = mustKey()
|
||||||
|
sshHostKey = mustKey()
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustKey() *ecdsa.PrivateKey {
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return priv
|
||||||
|
}
|
||||||
|
|
||||||
|
func signSSHCertificate(cert *ssh.Certificate) error {
|
||||||
|
signerKey, err := ssh.NewPublicKey(sshSignerKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
signer, err := ssh.NewSignerFromSigner(sshSignerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signerKey
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSignedUserCertificate() (*ssh.Certificate, error) {
|
||||||
|
key, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := time.Now()
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte("1234567890"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 1234567890,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "user@localhost",
|
||||||
|
ValidPrincipals: []string{"user"},
|
||||||
|
ValidAfter: uint64(t.Unix()),
|
||||||
|
ValidBefore: uint64(t.Add(time.Hour).Unix()),
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{},
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Reserved: []byte{},
|
||||||
|
}
|
||||||
|
if err := signSSHCertificate(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSignedHostCertificate() (*ssh.Certificate, error) {
|
||||||
|
key, err := ssh.NewPublicKey(sshHostKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := time.Now()
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte("1234567890"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 1234567890,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "internal.smallstep.com",
|
||||||
|
ValidPrincipals: []string{"internal.smallstep.com"},
|
||||||
|
ValidAfter: uint64(t.Unix()),
|
||||||
|
ValidBefore: uint64(t.Add(time.Hour).Unix()),
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{},
|
||||||
|
Extensions: map[string]string{},
|
||||||
|
},
|
||||||
|
Reserved: []byte{},
|
||||||
|
}
|
||||||
|
if err := signSSHCertificate(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCertificate_MarshalJSON(t *testing.T) {
|
||||||
|
user, err := getSignedUserCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
host, err := getSignedHostCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Certificate *ssh.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||||||
|
{"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false},
|
||||||
|
{"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := SSHCertificate{
|
||||||
|
Certificate: tt.fields.Certificate,
|
||||||
|
}
|
||||||
|
got, err := c.MarshalJSON()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
|
||||||
|
user, err := getSignedUserCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
host, err := getSignedHostCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *ssh.Certificate
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"null", args{[]byte(`null`)}, nil, false},
|
||||||
|
{"empty", args{[]byte(`""`)}, nil, false},
|
||||||
|
{"user", args{[]byte(`"` + userB64 + `"`)}, user, false},
|
||||||
|
{"host", args{[]byte(`"` + hostB64 + `"`)}, host, false},
|
||||||
|
{"bad-string", args{[]byte(userB64)}, nil, true},
|
||||||
|
{"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true},
|
||||||
|
{"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true},
|
||||||
|
{"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &SSHCertificate{}
|
||||||
|
if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(tt.want, c.Certificate) {
|
||||||
|
t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
PublicKey []byte
|
||||||
|
OTT string
|
||||||
|
CertType string
|
||||||
|
Principals []string
|
||||||
|
ValidAfter TimeDuration
|
||||||
|
ValidBefore TimeDuration
|
||||||
|
AddUserPublicKey []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||||
|
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||||
|
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||||
|
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := &SignSSHRequest{
|
||||||
|
PublicKey: tt.fields.PublicKey,
|
||||||
|
OTT: tt.fields.OTT,
|
||||||
|
CertType: tt.fields.CertType,
|
||||||
|
Principals: tt.fields.Principals,
|
||||||
|
ValidAfter: tt.fields.ValidAfter,
|
||||||
|
ValidBefore: tt.fields.ValidBefore,
|
||||||
|
AddUserPublicKey: tt.fields.AddUserPublicKey,
|
||||||
|
}
|
||||||
|
if err := s.Validate(); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_SignSSH(t *testing.T) {
|
||||||
|
user, err := getSignedUserCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
host, err := getSignedHostCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
|
||||||
|
userReq, err := json.Marshal(SignSSHRequest{
|
||||||
|
PublicKey: user.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hostReq, err := json.Marshal(SignSSHRequest{
|
||||||
|
PublicKey: host.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userAddReq, err := json.Marshal(SignSSHRequest{
|
||||||
|
PublicKey: user.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
AddUserPublicKey: user.Key.Marshal(),
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Authority Authority
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
r *http.Request
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req []byte
|
||||||
|
authErr error
|
||||||
|
signCert *ssh.Certificate
|
||||||
|
signErr error
|
||||||
|
addUserCert *ssh.Certificate
|
||||||
|
addUserErr error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
|
||||||
|
{"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
|
||||||
|
{"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
|
||||||
|
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
||||||
|
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
||||||
|
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{
|
||||||
|
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||||
|
return []provisioner.SignOption{}, tt.authErr
|
||||||
|
},
|
||||||
|
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
return tt.signCert, tt.signErr
|
||||||
|
},
|
||||||
|
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
|
return tt.addUserCert, tt.addUserErr
|
||||||
|
},
|
||||||
|
}).(*caHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SignSSH(logging.NewResponseLogger(w), req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
|
t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Method indicates the action to action that we will perform, it's used as part
|
||||||
|
// of the context in the call to authorize. It defaults to Sing.
|
||||||
|
type Method int
|
||||||
|
|
||||||
|
// The key to save the Method in the context.
|
||||||
|
type methodKey struct{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SignMethod is the method used to sign X.509 certificates.
|
||||||
|
SignMethod Method = iota
|
||||||
|
// SignSSHMethod is the method used to sign SSH certificate.
|
||||||
|
SignSSHMethod
|
||||||
|
// RevokeMethod is the method used to revoke X.509 certificates.
|
||||||
|
RevokeMethod
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewContextWithMethod creates a new context from ctx and attaches method to
|
||||||
|
// it.
|
||||||
|
func NewContextWithMethod(ctx context.Context, method Method) context.Context {
|
||||||
|
return context.WithValue(ctx, methodKey{}, method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodFromContext returns the Method saved in ctx. Returns Sign if the given
|
||||||
|
// context has no Method associated with it.
|
||||||
|
func MethodFromContext(ctx context.Context) Method {
|
||||||
|
m, _ := ctx.Value(methodKey{}).(Method)
|
||||||
|
return m
|
||||||
|
}
|
@ -0,0 +1,306 @@
|
|||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SSHUserCert is the string used to represent ssh.UserCert.
|
||||||
|
SSHUserCert = "user"
|
||||||
|
|
||||||
|
// SSHHostCert is the string used to represent ssh.HostCert.
|
||||||
|
SSHHostCert = "host"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHCertificateModifier is the interface used to change properties in an SSH
|
||||||
|
// certificate.
|
||||||
|
type SSHCertificateModifier interface {
|
||||||
|
SignOption
|
||||||
|
Modify(cert *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateOptionModifier is the interface used to add custom options used
|
||||||
|
// to modify the SSH certificate.
|
||||||
|
type SSHCertificateOptionModifier interface {
|
||||||
|
SignOption
|
||||||
|
Option(o SSHOptions) SSHCertificateModifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateValidator is the interface used to validate an SSH certificate.
|
||||||
|
type SSHCertificateValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(cert *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateOptionsValidator is the interface used to validate the custom
|
||||||
|
// options used to modify the SSH certificate.
|
||||||
|
type SSHCertificateOptionsValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(got SSHOptions) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHOptions contains the options that can be passed to the SignSSH method.
|
||||||
|
type SSHOptions struct {
|
||||||
|
CertType string `json:"certType"`
|
||||||
|
Principals []string `json:"principals"`
|
||||||
|
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||||
|
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the uint32 representation of the CertType.
|
||||||
|
func (o SSHOptions) Type() uint32 {
|
||||||
|
return sshCertTypeUInt32(o.CertType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
|
||||||
|
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
||||||
|
switch o.CertType {
|
||||||
|
case "": // ignore
|
||||||
|
case SSHUserCert:
|
||||||
|
cert.CertType = ssh.UserCert
|
||||||
|
case SSHHostCert:
|
||||||
|
cert.CertType = ssh.HostCert
|
||||||
|
default:
|
||||||
|
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
|
||||||
|
}
|
||||||
|
cert.ValidPrincipals = o.Principals
|
||||||
|
if !o.ValidAfter.IsZero() {
|
||||||
|
cert.ValidAfter = uint64(o.ValidAfter.Time().Unix())
|
||||||
|
}
|
||||||
|
if !o.ValidBefore.IsZero() {
|
||||||
|
cert.ValidBefore = uint64(o.ValidBefore.Time().Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
||||||
|
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// match compares two SSHOptions and return an error if they don't match. It
|
||||||
|
// ignores zero values.
|
||||||
|
func (o SSHOptions) match(got SSHOptions) error {
|
||||||
|
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
|
||||||
|
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
|
||||||
|
}
|
||||||
|
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
|
||||||
|
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
|
||||||
|
}
|
||||||
|
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
|
||||||
|
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
|
||||||
|
}
|
||||||
|
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
|
||||||
|
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
|
||||||
|
// Key ID in the SSH certificate.
|
||||||
|
type sshCertificateKeyIDModifier string
|
||||||
|
|
||||||
|
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.KeyId = string(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the
|
||||||
|
// certificate type to the SSH certificate.
|
||||||
|
type sshCertificateCertTypeModifier string
|
||||||
|
|
||||||
|
func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.CertType = sshCertTypeUInt32(string(m))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the
|
||||||
|
// principals to the SSH certificate.
|
||||||
|
type sshCertificatePrincipalsModifier []string
|
||||||
|
|
||||||
|
func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidPrincipals = []string(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
|
||||||
|
// ValidAfter in the SSH certificate.
|
||||||
|
type sshCertificateValidAfterModifier uint64
|
||||||
|
|
||||||
|
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidAfter = uint64(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
|
||||||
|
// ValidBefore in the SSH certificate.
|
||||||
|
type sshCertificateValidBeforeModifier uint64
|
||||||
|
|
||||||
|
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidBefore = uint64(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateDefaultModifier implements a SSHCertificateModifier that
|
||||||
|
// modifies the certificate with the given options if they are not set.
|
||||||
|
type sshCertificateDefaultsModifier SSHOptions
|
||||||
|
|
||||||
|
// Modify implements the SSHCertificateModifier interface.
|
||||||
|
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if cert.CertType == 0 {
|
||||||
|
cert.CertType = sshCertTypeUInt32(m.CertType)
|
||||||
|
}
|
||||||
|
if len(cert.ValidPrincipals) == 0 {
|
||||||
|
cert.ValidPrincipals = m.Principals
|
||||||
|
}
|
||||||
|
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
|
||||||
|
cert.ValidAfter = uint64(m.ValidAfter.Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
|
||||||
|
cert.ValidBefore = uint64(m.ValidBefore.Unix())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
|
||||||
|
// the default extensions in an SSH certificate.
|
||||||
|
type sshDefaultExtensionModifier struct{}
|
||||||
|
|
||||||
|
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
switch cert.CertType {
|
||||||
|
// Default to no extensions for HostCert.
|
||||||
|
case ssh.HostCert:
|
||||||
|
return nil
|
||||||
|
case ssh.UserCert:
|
||||||
|
if cert.Extensions == nil {
|
||||||
|
cert.Extensions = make(map[string]string)
|
||||||
|
}
|
||||||
|
cert.Extensions["permit-X11-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-agent-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-port-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-pty"] = ""
|
||||||
|
cert.Extensions["permit-user-rc"] = ""
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("ssh certificate type has not been set or is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidityModifier is a SSHCertificateModifier checks the
|
||||||
|
// validity bounds, setting them if they are not provided. It will fail if a
|
||||||
|
// CertType has not been set or is not valid.
|
||||||
|
type sshCertificateValidityModifier struct {
|
||||||
|
*Claimer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
var d, min, max time.Duration
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
d = m.DefaultUserSSHCertDuration()
|
||||||
|
min = m.MinUserSSHCertDuration()
|
||||||
|
max = m.MaxUserSSHCertDuration()
|
||||||
|
case ssh.HostCert:
|
||||||
|
d = m.DefaultHostSSHCertDuration()
|
||||||
|
min = m.MinHostSSHCertDuration()
|
||||||
|
max = m.MaxHostSSHCertDuration()
|
||||||
|
case 0:
|
||||||
|
return errors.New("ssh certificate type has not been set")
|
||||||
|
default:
|
||||||
|
return errors.Errorf("unknown ssh certificate type %d", cert.CertType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cert.ValidAfter == 0 {
|
||||||
|
cert.ValidAfter = uint64(now().Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 {
|
||||||
|
t := time.Unix(int64(cert.ValidAfter), 0)
|
||||||
|
cert.ValidBefore = uint64(t.Add(d).Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
|
||||||
|
switch {
|
||||||
|
case diff < min:
|
||||||
|
return errors.Errorf("ssh certificate duration cannot be lower than %s", min)
|
||||||
|
case diff > max:
|
||||||
|
return errors.Errorf("ssh certificate duration cannot be greater than %s", max)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
|
||||||
|
// usually present in the token.
|
||||||
|
type sshCertificateOptionsValidator SSHOptions
|
||||||
|
|
||||||
|
// Valid implements SSHCertificateOptionsValidator and returns nil if both
|
||||||
|
// SSHOptions match.
|
||||||
|
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
|
||||||
|
want := SSHOptions(v)
|
||||||
|
return want.match(got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateDefaultValidator implements a simple validator for all the
|
||||||
|
// fields in the SSH certificate.
|
||||||
|
type sshCertificateDefaultValidator struct{}
|
||||||
|
|
||||||
|
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||||
|
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
||||||
|
switch {
|
||||||
|
case len(cert.Nonce) == 0:
|
||||||
|
return errors.New("ssh certificate nonce cannot be empty")
|
||||||
|
case cert.Key == nil:
|
||||||
|
return errors.New("ssh certificate key cannot be nil")
|
||||||
|
case cert.Serial == 0:
|
||||||
|
return errors.New("ssh certificate serial cannot be 0")
|
||||||
|
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||||
|
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType)
|
||||||
|
case cert.KeyId == "":
|
||||||
|
return errors.New("ssh certificate key id cannot be empty")
|
||||||
|
case len(cert.ValidPrincipals) == 0:
|
||||||
|
return errors.New("ssh certificate valid principals cannot be empty")
|
||||||
|
case cert.ValidAfter == 0:
|
||||||
|
return errors.New("ssh certificate valid after cannot be 0")
|
||||||
|
case cert.ValidBefore == 0:
|
||||||
|
return errors.New("ssh certificate valid before cannot be 0")
|
||||||
|
case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0:
|
||||||
|
return errors.New("ssh certificate extensions cannot be empty")
|
||||||
|
case cert.SignatureKey == nil:
|
||||||
|
return errors.New("ssh certificate signature key cannot be nil")
|
||||||
|
case cert.Signature == nil:
|
||||||
|
return errors.New("ssh certificate signature cannot be nil")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertTypeUInt32
|
||||||
|
func sshCertTypeUInt32(ct string) uint32 {
|
||||||
|
switch ct {
|
||||||
|
case SSHUserCert:
|
||||||
|
return ssh.UserCert
|
||||||
|
case SSHHostCert:
|
||||||
|
return ssh.HostCert
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsAllMembers reports whether all members of subgroup are within group.
|
||||||
|
func containsAllMembers(group, subgroup []string) bool {
|
||||||
|
lg, lsg := len(group), len(subgroup)
|
||||||
|
if lsg > lg || (lg > 0 && lsg == 0) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
visit := make(map[string]struct{}, lg)
|
||||||
|
for i := 0; i < lg; i++ {
|
||||||
|
visit[group[i]] = struct{}{}
|
||||||
|
}
|
||||||
|
for i := 0; i < lsg; i++ {
|
||||||
|
if _, ok := visit[subgroup[i]]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
@ -0,0 +1,125 @@
|
|||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validateSSHCertificate(cert *ssh.Certificate, opts *SSHOptions) error {
|
||||||
|
switch {
|
||||||
|
case cert == nil:
|
||||||
|
return fmt.Errorf("certificate is nil")
|
||||||
|
case cert.Signature == nil:
|
||||||
|
return fmt.Errorf("certificate signature is nil")
|
||||||
|
case cert.SignatureKey == nil:
|
||||||
|
return fmt.Errorf("certificate signature is nil")
|
||||||
|
case !reflect.DeepEqual(cert.ValidPrincipals, opts.Principals):
|
||||||
|
return fmt.Errorf("certificate principals are not equal, want %v, got %v", opts.Principals, cert.ValidPrincipals)
|
||||||
|
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||||
|
return fmt.Errorf("certificate type %v is not valid", cert.CertType)
|
||||||
|
case opts.CertType == "user" && cert.CertType != ssh.UserCert:
|
||||||
|
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.UserCert, cert.CertType)
|
||||||
|
case opts.CertType == "host" && cert.CertType != ssh.HostCert:
|
||||||
|
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType)
|
||||||
|
case cert.ValidAfter != uint64(opts.ValidAfter.Unix()):
|
||||||
|
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
|
||||||
|
case cert.ValidBefore != uint64(opts.ValidBefore.Unix()):
|
||||||
|
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
|
||||||
|
case opts.CertType == "user" && len(cert.Extensions) != 5:
|
||||||
|
return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions))
|
||||||
|
case opts.CertType == "host" && len(cert.Extensions) != 0:
|
||||||
|
return fmt.Errorf("certificate extensions number is invalid, want 0, got %d", len(cert.Extensions))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOption, signKey crypto.Signer) (*ssh.Certificate, error) {
|
||||||
|
pub, err := ssh.NewPublicKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var mods []SSHCertificateModifier
|
||||||
|
var validators []SSHCertificateValidator
|
||||||
|
|
||||||
|
for _, op := range signOpts {
|
||||||
|
switch o := op.(type) {
|
||||||
|
// modify the ssh.Certificate
|
||||||
|
case SSHCertificateModifier:
|
||||||
|
mods = append(mods, o)
|
||||||
|
// modify the ssh.Certificate given the SSHOptions
|
||||||
|
case SSHCertificateOptionModifier:
|
||||||
|
mods = append(mods, o.Option(opts))
|
||||||
|
// validate the ssh.Certificate
|
||||||
|
case SSHCertificateValidator:
|
||||||
|
validators = append(validators, o)
|
||||||
|
// validate the given SSHOptions
|
||||||
|
case SSHCertificateOptionsValidator:
|
||||||
|
if err := o.Valid(opts); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("signSSH: invalid extra option type %T", o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build base certificate with the key and some random values
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0},
|
||||||
|
Key: pub,
|
||||||
|
Serial: 1234567890,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use opts to modify the certificate
|
||||||
|
if err := opts.Modify(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use provisioner modifiers
|
||||||
|
for _, m := range mods {
|
||||||
|
if err := m.Modify(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get signer from authority keys
|
||||||
|
var signer ssh.Signer
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
signer, err = ssh.NewSignerFromSigner(signKey)
|
||||||
|
case ssh.HostCert:
|
||||||
|
signer, err = ssh.NewSignerFromSigner(signKey)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected ssh certificate type: %d", cert.CertType)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signer.PublicKey()
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
|
||||||
|
// User provisioners validators
|
||||||
|
for _, v := range validators {
|
||||||
|
if err := v.Valid(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
@ -0,0 +1,239 @@
|
|||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SSHAddUserPrincipal is the principal that will run the add user command.
|
||||||
|
// Defaults to "provisioner" but it can be changed in the configuration.
|
||||||
|
SSHAddUserPrincipal = "provisioner"
|
||||||
|
|
||||||
|
// SSHAddUserCommand is the default command to run to add a new user.
|
||||||
|
// Defaults to "sudo useradd -m <principal>; nc -q0 localhost 22" but it can be changed in the
|
||||||
|
// configuration. The string "<principal>" will be replace by the new
|
||||||
|
// principal to add.
|
||||||
|
SSHAddUserCommand = "sudo useradd -m <principal>; nc -q0 localhost 22"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||||
|
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
var mods []provisioner.SSHCertificateModifier
|
||||||
|
var validators []provisioner.SSHCertificateValidator
|
||||||
|
|
||||||
|
for _, op := range signOpts {
|
||||||
|
switch o := op.(type) {
|
||||||
|
// modify the ssh.Certificate
|
||||||
|
case provisioner.SSHCertificateModifier:
|
||||||
|
mods = append(mods, o)
|
||||||
|
// modify the ssh.Certificate given the SSHOptions
|
||||||
|
case provisioner.SSHCertificateOptionModifier:
|
||||||
|
mods = append(mods, o.Option(opts))
|
||||||
|
// validate the ssh.Certificate
|
||||||
|
case provisioner.SSHCertificateValidator:
|
||||||
|
validators = append(validators, o)
|
||||||
|
// validate the given SSHOptions
|
||||||
|
case provisioner.SSHCertificateOptionsValidator:
|
||||||
|
if err := o.Valid(opts); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Errorf("signSSH: invalid extra option type %T", o),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := randutil.ASCII(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
|
||||||
|
var serial uint64
|
||||||
|
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error reading random number"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build base certificate with the key and some random values
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte(nonce),
|
||||||
|
Key: key,
|
||||||
|
Serial: serial,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use opts to modify the certificate
|
||||||
|
if err := opts.Modify(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use provisioner modifiers
|
||||||
|
for _, m := range mods {
|
||||||
|
if err := m.Modify(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get signer from authority keys
|
||||||
|
var signer ssh.Signer
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
if a.sshCAUserCertSignKey == nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSH: user certificate signing is not enabled"),
|
||||||
|
code: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error creating signer"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ssh.HostCert:
|
||||||
|
if a.sshCAHostCertSignKey == nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSH: host certificate signing is not enabled"),
|
||||||
|
code: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error creating signer"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signer.PublicKey()
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error signing certificate"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
|
||||||
|
// User provisioners validators
|
||||||
|
for _, v := range validators {
|
||||||
|
if err := v.Valid(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
||||||
|
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
|
if a.sshCAUserCertSignKey == nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
|
||||||
|
code: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if subject.CertType != ssh.UserCert {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSHProxy: certificate is not a user certificate"),
|
||||||
|
code: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(subject.ValidPrincipals) != 1 {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSHProxy: certificate does not have only one principal"),
|
||||||
|
code: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := randutil.ASCII(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
|
||||||
|
var serial uint64
|
||||||
|
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSHProxy: error reading random number"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ssh.NewSignerFromSigner(a.sshCAUserCertSignKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSHProxy: error creating signer"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
principal := subject.ValidPrincipals[0]
|
||||||
|
addUserPrincipal := a.getAddUserPrincipal()
|
||||||
|
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte(nonce),
|
||||||
|
Key: key,
|
||||||
|
Serial: serial,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: principal + "-" + addUserPrincipal,
|
||||||
|
ValidPrincipals: []string{addUserPrincipal},
|
||||||
|
ValidAfter: subject.ValidAfter,
|
||||||
|
ValidBefore: subject.ValidBefore,
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{
|
||||||
|
"force-command": a.getAddUserCommand(principal),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SignatureKey: signer.PublicKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||||
|
if a.config.SSH.AddUserPrincipal == "" {
|
||||||
|
return SSHAddUserPrincipal
|
||||||
|
}
|
||||||
|
return a.config.SSH.AddUserPrincipal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) getAddUserCommand(principal string) string {
|
||||||
|
var cmd string
|
||||||
|
if a.config.SSH.AddUserCommand == "" {
|
||||||
|
cmd = SSHAddUserCommand
|
||||||
|
} else {
|
||||||
|
cmd = a.config.SSH.AddUserCommand
|
||||||
|
}
|
||||||
|
return strings.Replace(cmd, "<principal>", principal, -1)
|
||||||
|
}
|
@ -0,0 +1,252 @@
|
|||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sshTestModifier ssh.Certificate
|
||||||
|
|
||||||
|
func (m sshTestModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if m.CertType != 0 {
|
||||||
|
cert.CertType = m.CertType
|
||||||
|
}
|
||||||
|
if m.KeyId != "" {
|
||||||
|
cert.KeyId = m.KeyId
|
||||||
|
}
|
||||||
|
if m.ValidAfter != 0 {
|
||||||
|
cert.ValidAfter = m.ValidAfter
|
||||||
|
}
|
||||||
|
if m.ValidBefore != 0 {
|
||||||
|
cert.ValidBefore = m.ValidBefore
|
||||||
|
}
|
||||||
|
if len(m.ValidPrincipals) != 0 {
|
||||||
|
cert.ValidPrincipals = m.ValidPrincipals
|
||||||
|
}
|
||||||
|
if m.Permissions.CriticalOptions != nil {
|
||||||
|
cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions
|
||||||
|
}
|
||||||
|
if m.Permissions.Extensions != nil {
|
||||||
|
cert.Permissions.Extensions = m.Permissions.Extensions
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestCertModifier string
|
||||||
|
|
||||||
|
func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if m == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf(string(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestCertValidator string
|
||||||
|
|
||||||
|
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf(string(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestOptionsValidator string
|
||||||
|
|
||||||
|
func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf(string(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestOptionsModifier string
|
||||||
|
|
||||||
|
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
|
||||||
|
return sshTestCertModifier(string(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthority_SignSSH(t *testing.T) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
pub, err := ssh.NewPublicKey(key.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userOptions := sshTestModifier{
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
}
|
||||||
|
hostOptions := sshTestModifier{
|
||||||
|
CertType: ssh.HostCert,
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
sshCAUserCertSignKey crypto.Signer
|
||||||
|
sshCAHostCertSignKey crypto.Signer
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
key ssh.PublicKey
|
||||||
|
opts provisioner.SSHOptions
|
||||||
|
signOpts []provisioner.SignOption
|
||||||
|
}
|
||||||
|
type want struct {
|
||||||
|
CertType uint32
|
||||||
|
Principals []string
|
||||||
|
ValidAfter uint64
|
||||||
|
ValidBefore uint64
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want want
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{hostOptions}}, want{CertType: ssh.HostCert}, false},
|
||||||
|
{"ok-opts-type-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-opts-type-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert}, false},
|
||||||
|
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
|
||||||
|
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
|
||||||
|
{"ok-opts-valid-after", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false},
|
||||||
|
{"ok-opts-valid-before", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false},
|
||||||
|
{"ok-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"fail-opts-type", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "foo"}, []provisioner.SignOption{}}, want{}, true},
|
||||||
|
{"fail-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("an error")}}, want{}, true},
|
||||||
|
{"fail-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("an error")}}, want{}, true},
|
||||||
|
{"fail-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("an error")}}, want{}, true},
|
||||||
|
{"fail-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("an error")}}, want{}, true},
|
||||||
|
{"fail-bad-sign-options", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, "wrong type"}}, want{}, true},
|
||||||
|
{"fail-no-user-key", fields{nil, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{}, true},
|
||||||
|
{"fail-no-host-key", fields{signKey, nil}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{}, true},
|
||||||
|
{"fail-bad-type", fields{signKey, nil}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{sshTestModifier{CertType: 0}}}, want{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := testAuthority(t)
|
||||||
|
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||||
|
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||||
|
|
||||||
|
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && assert.NotNil(t, got) {
|
||||||
|
assert.Equals(t, tt.want.CertType, got.CertType)
|
||||||
|
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
|
||||||
|
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
|
||||||
|
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
|
||||||
|
assert.NotNil(t, got.Key)
|
||||||
|
assert.NotNil(t, got.Nonce)
|
||||||
|
assert.NotEquals(t, 0, got.Serial)
|
||||||
|
assert.NotNil(t, got.Signature)
|
||||||
|
assert.NotNil(t, got.SignatureKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthority_SignSSHAddUser(t *testing.T) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
pub, err := ssh.NewPublicKey(key.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
sshCAUserCertSignKey crypto.Signer
|
||||||
|
sshCAHostCertSignKey crypto.Signer
|
||||||
|
addUserPrincipal string
|
||||||
|
addUserCommand string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
key ssh.PublicKey
|
||||||
|
subject *ssh.Certificate
|
||||||
|
}
|
||||||
|
type want struct {
|
||||||
|
CertType uint32
|
||||||
|
Principals []string
|
||||||
|
ValidAfter uint64
|
||||||
|
ValidBefore uint64
|
||||||
|
ForceCommand string
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
validCert := &ssh.Certificate{
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
ValidPrincipals: []string{"user"},
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
validWant := want{
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
Principals: []string{"provisioner"},
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
ForceCommand: "sudo useradd -m user; nc -q0 localhost 22",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want want
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{signKey, signKey, "", ""}, args{pub, validCert}, validWant, false},
|
||||||
|
{"ok-no-host-key", fields{signKey, nil, "", ""}, args{pub, validCert}, validWant, false},
|
||||||
|
{"ok-custom-principal", fields{signKey, signKey, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false},
|
||||||
|
{"ok-custom-command", fields{signKey, signKey, "", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false},
|
||||||
|
{"ok-custom-principal-and-command", fields{signKey, signKey, "my-principal", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false},
|
||||||
|
{"fail-no-user-key", fields{nil, signKey, "", ""}, args{pub, validCert}, want{}, true},
|
||||||
|
{"fail-no-user-cert", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true},
|
||||||
|
{"fail-no-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true},
|
||||||
|
{"fail-many-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := testAuthority(t)
|
||||||
|
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||||
|
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||||
|
a.config.SSH = &SSHConfig{
|
||||||
|
AddUserPrincipal: tt.fields.addUserPrincipal,
|
||||||
|
AddUserCommand: tt.fields.addUserCommand,
|
||||||
|
}
|
||||||
|
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && assert.NotNil(t, got) {
|
||||||
|
assert.Equals(t, tt.want.CertType, got.CertType)
|
||||||
|
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
|
||||||
|
assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId)
|
||||||
|
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
|
||||||
|
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
|
||||||
|
assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions)
|
||||||
|
assert.Equals(t, nil, got.Extensions)
|
||||||
|
assert.NotNil(t, got.Key)
|
||||||
|
assert.NotNil(t, got.Nonce)
|
||||||
|
assert.NotEquals(t, 0, got.Serial)
|
||||||
|
assert.NotNil(t, got.Signature)
|
||||||
|
assert.NotNil(t, got.SignatureKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue