Merge pull request #85 from smallstep/ssh-ca

SSH CA support
pull/102/head
Max 5 years ago committed by GitHub
commit 7726f5ec75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

8
Gopkg.lock generated

@ -262,15 +262,20 @@
[[projects]]
branch = "master"
digest = "1:5dd7da6df07f42194cb25d162b4b89664ed7b08d7d4334f6a288393d54b095ce"
digest = "1:afc49fe39c8c591fc2c8ddc73adc4c69e67125dde6c58e24c91b3b0cf78602be"
name = "golang.org/x/crypto"
packages = [
"cryptobyte",
"cryptobyte/asn1",
"curve25519",
"ed25519",
"ed25519/internal/edwards25519",
"internal/chacha20",
"internal/subtle",
"ocsp",
"pbkdf2",
"poly1305",
"ssh",
"ssh/terminal",
]
pruneopts = "UT"
@ -394,6 +399,7 @@
"github.com/urfave/cli",
"golang.org/x/crypto/ed25519",
"golang.org/x/crypto/ocsp",
"golang.org/x/crypto/ssh",
"golang.org/x/net/http2",
"gopkg.in/square/go-jose.v2",
"gopkg.in/square/go-jose.v2/jwt",

@ -1,6 +1,7 @@
package api
import (
"context"
"crypto/dsa"
"crypto/ecdsa"
"crypto/rsa"
@ -26,9 +27,10 @@ import (
// Authority is the interface implemented by a CA authority.
type Authority interface {
SSHAuthority
// NOTE: Authorize will be deprecated in future releases. Please use the
// context specific Authoirize[Sign|Revoke|etc.] methods.
Authorize(ott string) ([]provisioner.SignOption, error)
// context specific Authorize[Sign|Revoke|etc.] methods.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
GetTLSOptions() *tlsutil.TLSOptions
Root(shasum string) (*x509.Certificate, error)
@ -249,6 +251,8 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("GET", "/federation", h.Federation)
// For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew)
// SSH CA
r.MethodFunc("POST", "/sign-ssh", h.SignSSH)
}
// Health is an HTTP handler that returns the status of the server.

@ -31,6 +31,7 @@ import (
"github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
)
const (
@ -424,7 +425,7 @@ type mockProvisioner struct {
getEncryptedKey func() (string, string, bool)
init func(provisioner.Config) error
authorizeRevoke func(ott string) error
authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
authorizeRenewal func(*x509.Certificate) error
}
@ -480,9 +481,9 @@ func (m *mockProvisioner) AuthorizeRevoke(ott string) error {
return m.err
}
func (m *mockProvisioner) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
if m.authorizeSign != nil {
return m.authorizeSign(ott)
return m.authorizeSign(ctx, ott)
}
return m.ret1.([]provisioner.SignOption), m.err
}
@ -501,6 +502,8 @@ type mockAuthority struct {
getTLSOptions func() *tlsutil.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
@ -511,7 +514,7 @@ type mockAuthority struct {
}
// TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return m.AuthorizeSign(ott)
}
@ -543,6 +546,20 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
}
func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
if m.signSSH != nil {
return m.signSSH(key, opts, signOpts...)
}
return m.ret1.(*ssh.Certificate), m.err
}
func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
if m.signSSHAddUser != nil {
return m.signSSHAddUser(key, cert)
}
return m.ret1.(*ssh.Certificate), m.err
}
func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
if m.renew != nil {
return m.renew(cert)

@ -0,0 +1,159 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"golang.org/x/crypto/ssh"
)
// SSHAuthority is the interface implemented by a SSH CA authority.
type SSHAuthority interface {
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
}
// SignSSHRequest is the request body of an SSH certificate request.
type SignSSHRequest struct {
PublicKey []byte `json:"publicKey"` //base64 encoded
OTT string `json:"ott"`
CertType string `json:"certType,omitempty"`
Principals []string `json:"principals,omitempty"`
ValidAfter TimeDuration `json:"validAfter,omitempty"`
ValidBefore TimeDuration `json:"validBefore,omitempty"`
AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"`
}
// SignSSHResponse is the response object that returns the SSH certificate.
type SignSSHResponse struct {
Certificate SSHCertificate `json:"crt"`
AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"`
}
// SSHCertificate represents the response SSH certificate.
type SSHCertificate struct {
*ssh.Certificate `json:"omitempty"`
}
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
// base64 encoded, openssh wire format version of the certificate.
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
if c.Certificate == nil {
return []byte("null"), nil
}
s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal())
return []byte(`"` + s + `"`), nil
}
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is
// expected to be a quoted, base64 encoded, openssh wire formatted block of bytes.
func (c *SSHCertificate) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "error decoding certificate")
}
if s == "" {
c.Certificate = nil
return nil
}
certData, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrap(err, "error decoding ssh certificate")
}
pub, err := ssh.ParsePublicKey(certData)
if err != nil {
return errors.Wrap(err, "error parsing ssh certificate")
}
cert, ok := pub.(*ssh.Certificate)
if !ok {
return errors.Errorf("error decoding ssh certificate: %T is not an *ssh.Certificate", pub)
}
c.Certificate = cert
return nil
}
// Validate validates the SignSSHRequest.
func (s *SignSSHRequest) Validate() error {
switch {
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
return errors.Errorf("unknown certType %s", s.CertType)
case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey")
case len(s.OTT) == 0:
return errors.New("missing or empty ott")
default:
return nil
}
}
// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in
// the request.
func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) {
var body SignSSHRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
return
}
var addUserPublicKey ssh.PublicKey
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
return
}
}
opts := provisioner.SSHOptions{
CertType: body.CertType,
Principals: body.Principals,
ValidBefore: body.ValidBefore,
ValidAfter: body.ValidAfter,
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
return
}
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
if err != nil {
WriteError(w, Forbidden(err))
return
}
var addUserCertificate *SSHCertificate
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
if err != nil {
WriteError(w, Forbidden(err))
return
}
addUserCertificate = &SSHCertificate{addUserCert}
}
w.WriteHeader(http.StatusCreated)
JSON(w, &SignSSHResponse{
Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate,
})
}

@ -0,0 +1,327 @@
package api
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ssh"
)
var (
sshSignerKey = mustKey()
sshUserKey = mustKey()
sshHostKey = mustKey()
)
func mustKey() *ecdsa.PrivateKey {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
return priv
}
func signSSHCertificate(cert *ssh.Certificate) error {
signerKey, err := ssh.NewPublicKey(sshSignerKey.Public())
if err != nil {
return err
}
signer, err := ssh.NewSignerFromSigner(sshSignerKey)
if err != nil {
return err
}
cert.SignatureKey = signerKey
data := cert.Marshal()
data = data[:len(data)-4]
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return err
}
cert.Signature = sig
return nil
}
func getSignedUserCertificate() (*ssh.Certificate, error) {
key, err := ssh.NewPublicKey(sshUserKey.Public())
if err != nil {
return nil, err
}
t := time.Now()
cert := &ssh.Certificate{
Nonce: []byte("1234567890"),
Key: key,
Serial: 1234567890,
CertType: ssh.UserCert,
KeyId: "user@localhost",
ValidPrincipals: []string{"user"},
ValidAfter: uint64(t.Unix()),
ValidBefore: uint64(t.Add(time.Hour).Unix()),
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
},
},
Reserved: []byte{},
}
if err := signSSHCertificate(cert); err != nil {
return nil, err
}
return cert, nil
}
func getSignedHostCertificate() (*ssh.Certificate, error) {
key, err := ssh.NewPublicKey(sshHostKey.Public())
if err != nil {
return nil, err
}
t := time.Now()
cert := &ssh.Certificate{
Nonce: []byte("1234567890"),
Key: key,
Serial: 1234567890,
CertType: ssh.UserCert,
KeyId: "internal.smallstep.com",
ValidPrincipals: []string{"internal.smallstep.com"},
ValidAfter: uint64(t.Unix()),
ValidBefore: uint64(t.Add(time.Hour).Unix()),
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
Reserved: []byte{},
}
if err := signSSHCertificate(cert); err != nil {
return nil, err
}
return cert, nil
}
func TestSSHCertificate_MarshalJSON(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
host, err := getSignedHostCertificate()
assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
type fields struct {
Certificate *ssh.Certificate
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"nil", fields{Certificate: nil}, []byte("null"), false},
{"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false},
{"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := SSHCertificate{
Certificate: tt.fields.Certificate,
}
got, err := c.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want)
}
})
}
}
func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
host, err := getSignedHostCertificate()
assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())
type args struct {
data []byte
}
tests := []struct {
name string
args args
want *ssh.Certificate
wantErr bool
}{
{"null", args{[]byte(`null`)}, nil, false},
{"empty", args{[]byte(`""`)}, nil, false},
{"user", args{[]byte(`"` + userB64 + `"`)}, user, false},
{"host", args{[]byte(`"` + hostB64 + `"`)}, host, false},
{"bad-string", args{[]byte(userB64)}, nil, true},
{"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true},
{"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true},
{"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &SSHCertificate{}
if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(tt.want, c.Certificate) {
t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want)
}
})
}
}
func TestSignSSHRequest_Validate(t *testing.T) {
type fields struct {
PublicKey []byte
OTT string
CertType string
Principals []string
ValidAfter TimeDuration
ValidBefore TimeDuration
AddUserPublicKey []byte
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &SignSSHRequest{
PublicKey: tt.fields.PublicKey,
OTT: tt.fields.OTT,
CertType: tt.fields.CertType,
Principals: tt.fields.Principals,
ValidAfter: tt.fields.ValidAfter,
ValidBefore: tt.fields.ValidBefore,
AddUserPublicKey: tt.fields.AddUserPublicKey,
}
if err := s.Validate(); (err != nil) != tt.wantErr {
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_caHandler_SignSSH(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
host, err := getSignedHostCertificate()
assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
userReq, err := json.Marshal(SignSSHRequest{
PublicKey: user.Key.Marshal(),
OTT: "ott",
})
assert.FatalError(t, err)
hostReq, err := json.Marshal(SignSSHRequest{
PublicKey: host.Key.Marshal(),
OTT: "ott",
})
assert.FatalError(t, err)
userAddReq, err := json.Marshal(SignSSHRequest{
PublicKey: user.Key.Marshal(),
OTT: "ott",
AddUserPublicKey: user.Key.Marshal(),
})
assert.FatalError(t, err)
type fields struct {
Authority Authority
}
type args struct {
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
req []byte
authErr error
signCert *ssh.Certificate
signErr error
addUserCert *ssh.Certificate
addUserErr error
body []byte
statusCode int
}{
{"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
{"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
{"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized},
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
return []provisioner.SignOption{}, tt.authErr
},
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
return tt.signCert, tt.signErr
},
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
return tt.addUserCert, tt.addUserErr
},
}).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req))
w := httptest.NewRecorder()
h.SignSSH(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body)
}
}
})
}
}

@ -1,12 +1,14 @@
package authority
import (
"crypto"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"sync"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/crypto/pemutil"
@ -20,6 +22,8 @@ type Authority struct {
config *Config
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
sshCAUserCertSignKey crypto.Signer
sshCAHostCertSignKey crypto.Signer
validateOnce bool
certificates *sync.Map
startTime time.Time
@ -117,6 +121,22 @@ func (a *Authority) init() error {
}
}
// Decrypt and load SSH keys
if a.config.SSH != nil {
if a.config.SSH.HostKey != "" {
a.sshCAHostCertSignKey, err = parseCryptoSigner(a.config.SSH.HostKey, a.config.Password)
if err != nil {
return err
}
}
if a.config.SSH.UserKey != "" {
a.sshCAUserCertSignKey, err = parseCryptoSigner(a.config.SSH.UserKey, a.config.Password)
if err != nil {
return err
}
}
}
// Store all the provisioners
for _, p := range a.config.AuthorityConfig.Provisioners {
if err := a.provisioners.Store(p); err != nil {
@ -143,3 +163,19 @@ func (a *Authority) GetDatabase() db.AuthDB {
func (a *Authority) Shutdown() error {
return a.db.Shutdown()
}
func parseCryptoSigner(filename, password string) (crypto.Signer, error) {
var opts []pemutil.Options
if password != "" {
opts = append(opts, pemutil.WithPassword([]byte(password)))
}
key, err := pemutil.Read(filename, opts...)
if err != nil {
return nil, err
}
signer, ok := key.(crypto.Signer)
if !ok {
return nil, errors.Errorf("key %s of type %T cannot be used for signing operations", filename, key)
}
return signer, nil
}

@ -1,6 +1,7 @@
package authority
import (
"context"
"crypto/x509"
"net/http"
"strings"
@ -72,33 +73,51 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
return p, nil
}
// Authorize is a passthrough to AuthorizeSign.
// NOTE: Authorize will be deprecated in a future release. Please use the
// context specific Authorize[Sign|Revoke|etc.] going forwards.
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
return a.AuthorizeSign(ott)
// Authorize grabs the method from the context and authorizes a signature
// request by validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod:
return a.authorizeSign(ctx, ott)
case provisioner.SignSSHMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
}
return a.authorizeSign(ctx, ott)
case provisioner.RevokeMethod:
return nil, &apiError{errors.New("authorize: revoke method is not supported"), http.StatusInternalServerError, errContext}
default:
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
}
}
// AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
var errContext = context{"ott": ott}
// authorizeSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSign method. Returns a
// list of methods to apply to the signing flow.
func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
p, err := a.authorizeToken(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
// Call the provisioner AuthorizeSign method to apply provisioner specific
// auth claims and get the signing options.
opts, err := p.AuthorizeSign(ott)
opts, err := p.AuthorizeSign(ctx, ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
return opts, nil
}
// AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
//
// NOTE: This method is deprecated and should not be used. We make it available
// in the short term os as not to break existing clients.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
return a.Authorize(ctx, ott)
}
// authorizeRevoke authorizes a revocation request by validating and authenticating
// the RevokeOptions POSTed with the request.
// Returns a tuple of the provisioner ID and error, if one occurred.

@ -1,11 +1,14 @@
package authority
import (
"context"
"crypto/x509"
"net/http"
"testing"
"time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/pemutil"
@ -72,7 +75,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
}
},
"fail/prehistoric-token": func(t *testing.T) *authorizeTest {
@ -91,7 +94,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
@ -113,7 +116,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok/simpledb": func(t *testing.T) *authorizeTest {
@ -150,7 +153,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a,
ott: raw,
err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
@ -198,7 +201,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a,
ott: raw,
err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"),
http.StatusInternalServerError, context{"ott": raw}},
http.StatusInternalServerError, apiCtx{"ott": raw}},
}
},
"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
@ -223,7 +226,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
auth: _a,
ott: raw,
err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
}
@ -388,7 +391,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
}
},
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
@ -406,7 +409,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok": func(t *testing.T) *authorizeTest {
@ -480,7 +483,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
}
},
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
@ -498,7 +501,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}},
http.StatusUnauthorized, apiCtx{"ott": raw}},
}
},
"ok": func(t *testing.T) *authorizeTest {
@ -522,8 +525,8 @@ func TestAuthority_Authorize(t *testing.T) {
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
got, err := tc.auth.Authorize(tc.ott)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
got, err := tc.auth.Authorize(ctx, tc.ott)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Nil(t, got)
@ -573,7 +576,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: fooCrt,
err: &apiError{errors.New("renew: force"),
http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}},
http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
}
},
"fail/revoked": func(t *testing.T) *authorizeTest {
@ -587,7 +590,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: fooCrt,
err: &apiError{errors.New("renew: certificate has been revoked"),
http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}},
http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
}
},
"fail/load-provisioner": func(t *testing.T) *authorizeTest {
@ -601,7 +604,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: otherCrt,
err: &apiError{errors.New("renew: provisioner not found"),
http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}},
http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}},
}
},
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
@ -616,7 +619,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
auth: a,
crt: renewDisabledCrt,
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}},
http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}},
}
},
"ok": func(t *testing.T) *authorizeTest {

@ -28,11 +28,19 @@ var (
Renegotiation: false,
}
defaultDisableRenewal = false
defaultEnableSSHCA = false
globalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 4 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA,
}
)
@ -44,6 +52,7 @@ type Config struct {
IntermediateKey string `json:"key"`
Address string `json:"address"`
DNSNames []string `json:"dnsNames"`
SSH *SSHConfig `json:"ssh,omitempty"`
Logger json.RawMessage `json:"logger,omitempty"`
DB *db.Config `json:"db,omitempty"`
Monitoring json.RawMessage `json:"monitoring,omitempty"`
@ -92,6 +101,14 @@ func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
return nil
}
// SSHConfig contains the user and host keys.
type SSHConfig struct {
HostKey string `json:"hostKey"`
UserKey string `json:"userKey"`
AddUserPrincipal string `json:"addUserPrincipal"`
AddUserCommand string `json:"addUserCommand"`
}
// LoadConfiguration parses the given filename in JSON format and returns the
// configuration struct.
func LoadConfiguration(filename string) (*Config, error) {

@ -4,13 +4,13 @@ import (
"net/http"
)
type context map[string]interface{}
type apiCtx map[string]interface{}
// Error implements the api.Error interface and adds context to error messages.
type apiError struct {
err error
code int
context context
context apiCtx
}
// Cause implements the errors.Causer interface and returns the original error.

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
@ -266,13 +267,21 @@ func (p *AWS) Init(config Config) (err error) {
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
payload, err := p.authorizeToken(token)
if err != nil {
return nil, err
}
doc := payload.document
// Check for the sign ssh method, default to sign X.509
if m := MethodFromContext(ctx); m == SignSSHMethod {
if p.claimer.IsSSHCAEnabled() == false {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
}
return p.authorizeSSHSign(payload)
}
doc := payload.document
// Enforce known CN and default DNS and IP if configured.
// By default we'll accept the CN and SANs in the CSR.
// There's no way to trust them other than TOFU.
@ -433,3 +442,35 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
payload.document = doc
return &payload, nil
}
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) authorizeSSHSign(claims *awsPayload) ([]SignOption, error) {
doc := claims.document
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(claims.Subject),
}
// Default to host + known IPs/hostnames
defaults := SSHOptions{
CertType: SSHHostCert,
Principals: []string{
doc.PrivateIP,
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
return append(signOptions,
// set the default extensions
&sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer},
// require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
), nil
}

@ -1,6 +1,8 @@
package provisioner
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
@ -347,7 +349,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.aws.AuthorizeSign(tt.args.token)
ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
@ -357,6 +360,84 @@ func TestAWS_AuthorizeSign(t *testing.T) {
}
}
func TestAWS_AuthorizeSign_SSH(t *testing.T) {
tm, fn := mockNow()
defer fn()
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
assert.FatalError(t, err)
key, err := generateJSONWebKey()
assert.FatalError(t, err)
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
expectedHostOptionsIP := &SSHOptions{
CertType: "host", Principals: []string{"127.0.0.1"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
expectedHostOptionsHostname := &SSHOptions{
CertType: "host", Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
type args struct {
token string
sshOpts SSHOptions
}
tests := []struct {
name string
aws *AWS
args args
expected *SSHOptions
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}}, expectedHostOptionsIP, false, false},
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptionsHostname, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}}, nil, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {
assert.Nil(t, cert)
} else {
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
}
}
}
})
}
}
func TestAWS_AuthorizeRenewal(t *testing.T) {
p1, err := generateAWS()
assert.FatalError(t, err)

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
@ -209,7 +210,7 @@ func (p *Azure) Init(config Config) (err error) {
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
@ -264,6 +265,14 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
}
}
// Check for the sign ssh method, default to sign X.509
if m := MethodFromContext(ctx); m == SignSSHMethod {
if p.claimer.IsSSHCAEnabled() == false {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
}
return p.authorizeSSHSign(claims, name)
}
// Enforce known common name and default DNS if configured.
// By default we'll accept the CN and SANs in the CSR.
// There's no way to trust them other than TOFU.
@ -296,6 +305,33 @@ func (p *Azure) AuthorizeRevoke(token string) error {
return errors.New("revoke is not supported on a Azure provisioner")
}
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) authorizeSSHSign(claims azurePayload, name string) ([]SignOption, error) {
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(name),
}
// Default to host + known hostnames
defaults := SSHOptions{
CertType: SSHHostCert,
Principals: []string{name},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
return append(signOptions,
// set the default extensions
&sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer},
// require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
), nil
}
// assertConfig initializes the config if it has not been initialized
func (p *Azure) assertConfig() {
if p.config == nil {

@ -1,6 +1,8 @@
package provisioner
import (
"context"
"crypto"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
@ -295,7 +297,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.azure.AuthorizeSign(tt.args.token)
ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
@ -305,6 +308,75 @@ func TestAzure_AuthorizeSign(t *testing.T) {
}
}
func TestAzure_AuthorizeSign_SSH(t *testing.T) {
tm, fn := mockNow()
defer fn()
p1, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
t1, err := p1.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
key, err := generateJSONWebKey()
assert.FatalError(t, err)
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"virtualMachine"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
type args struct {
token string
sshOpts SSHOptions
}
tests := []struct {
name string
azure *Azure
args args
expected *SSHOptions
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}}, nil, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {
assert.Nil(t, cert)
} else {
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
}
}
}
})
}
}
func TestAzure_AuthorizeRenewal(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)

@ -8,10 +8,19 @@ import (
// Claims so that individual provisioners can override global claims.
type Claims struct {
// TLS CA properties
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
// SSH CA properties
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"`
MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"`
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
}
// Claimer is the type that controls claims. It provides an interface around the
@ -30,11 +39,19 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal()
enableSSHCA := c.IsSSHCAEnabled()
return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal,
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal,
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
}
}
@ -78,6 +95,76 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal
}
// DefaultUserSSHCertDuration returns the default SSH user cert duration for the
// provisioner. If the default is not set within the provisioner, then the
// global default from the authority configuration will be used.
func (c *Claimer) DefaultUserSSHCertDuration() time.Duration {
if c.claims == nil || c.claims.DefaultUserSSHDur == nil {
return c.global.DefaultUserSSHDur.Duration
}
return c.claims.DefaultUserSSHDur.Duration
}
// MinUserSSHCertDuration returns the minimum SSH user cert duration for the
// provisioner. If the minimum is not set within the provisioner, then the
// global minimum from the authority configuration will be used.
func (c *Claimer) MinUserSSHCertDuration() time.Duration {
if c.claims == nil || c.claims.MinUserSSHDur == nil {
return c.global.MinUserSSHDur.Duration
}
return c.claims.MinUserSSHDur.Duration
}
// MaxUserSSHCertDuration returns the maximum SSH user cert duration for the
// provisioner. If the maximum is not set within the provisioner, then the
// global maximum from the authority configuration will be used.
func (c *Claimer) MaxUserSSHCertDuration() time.Duration {
if c.claims == nil || c.claims.MaxUserSSHDur == nil {
return c.global.MaxUserSSHDur.Duration
}
return c.claims.MaxUserSSHDur.Duration
}
// DefaultHostSSHCertDuration returns the default SSH host cert duration for the
// provisioner. If the default is not set within the provisioner, then the
// global default from the authority configuration will be used.
func (c *Claimer) DefaultHostSSHCertDuration() time.Duration {
if c.claims == nil || c.claims.DefaultHostSSHDur == nil {
return c.global.DefaultHostSSHDur.Duration
}
return c.claims.DefaultHostSSHDur.Duration
}
// MinHostSSHCertDuration returns the minimum SSH host cert duration for the
// provisioner. If the minimum is not set within the provisioner, then the
// global minimum from the authority configuration will be used.
func (c *Claimer) MinHostSSHCertDuration() time.Duration {
if c.claims == nil || c.claims.MinHostSSHDur == nil {
return c.global.MinHostSSHDur.Duration
}
return c.claims.MinHostSSHDur.Duration
}
// MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the
// provisioner. If the maximum is not set within the provisioner, then the
// global maximum from the authority configuration will be used.
func (c *Claimer) MaxHostSSHCertDuration() time.Duration {
if c.claims == nil || c.claims.MaxHostSSHDur == nil {
return c.global.MaxHostSSHDur.Duration
}
return c.claims.MaxHostSSHDur.Duration
}
// IsSSHCAEnabled returns if the SSH CA is enabled for the provisioner. If the
// property is not set within the provisioner, then the global value from the
// authority configuration will be used.
func (c *Claimer) IsSSHCAEnabled() bool {
if c.claims == nil || c.claims.EnableSSHCA == nil {
return *c.global.EnableSSHCA
}
return *c.claims.EnableSSHCA
}
// Validate validates and modifies the Claims with default values.
func (c *Claimer) Validate() error {
var (

@ -2,6 +2,7 @@ package provisioner
import (
"bytes"
"context"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
@ -205,13 +206,21 @@ func (p *GCP) Init(config Config) error {
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) {
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token)
if err != nil {
return nil, err
}
ce := claims.Google.ComputeEngine
// Check for the sign ssh method, default to sign X.509
if m := MethodFromContext(ctx); m == SignSSHMethod {
if p.claimer.IsSSHCAEnabled() == false {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
}
return p.authorizeSSHSign(claims)
}
ce := claims.Google.ComputeEngine
// Enforce known common name and default DNS if configured.
// By default we we'll accept the CN and SANs in the CSR.
// There's no way to trust them other than TOFU.
@ -345,3 +354,35 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
return &claims, nil
}
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) authorizeSSHSign(claims *gcpPayload) ([]SignOption, error) {
ce := claims.Google.ComputeEngine
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(ce.InstanceName),
}
// Default to host + known hostnames
defaults := SSHOptions{
CertType: SSHHostCert,
Principals: []string{
fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID),
fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID),
},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
return append(signOptions,
// set the default extensions
&sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer},
// require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
), nil
}

@ -1,6 +1,8 @@
package provisioner
import (
"context"
"crypto"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
@ -330,7 +332,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.gcp.AuthorizeSign(tt.args.token)
ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
@ -340,6 +343,87 @@ func TestGCP_AuthorizeSign(t *testing.T) {
}
}
func TestGCP_AuthorizeSign_SSH(t *testing.T) {
tm, fn := mockNow()
defer fn()
p1, err := generateGCP()
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
key, err := generateJSONWebKey()
assert.FatalError(t, err)
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
expectedHostOptionsPrincipal1 := &SSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
expectedHostOptionsPrincipal2 := &SSHOptions{
CertType: "host", Principals: []string{"instance-name.zone.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
type args struct {
token string
sshOpts SSHOptions
}
tests := []struct {
name string
gcp *GCP
args args
expected *SSHOptions
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}}, expectedHostOptionsPrincipal1, false, false},
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}}, expectedHostOptionsPrincipal2, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}}, nil, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {
assert.Nil(t, cert)
} else {
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
}
}
}
})
}
}
func TestGCP_AuthorizeRenewal(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/x509"
"time"
@ -12,7 +13,12 @@ import (
// jwtPayload extends jwt.Claims with step attributes.
type jwtPayload struct {
jose.Claims
SANs []string `json:"sans,omitempty"`
SANs []string `json:"sans,omitempty"`
Step *stepPayload `json:"step,omitempty"`
}
type stepPayload struct {
SSH *SSHOptions `json:"ssh,omitempty"`
}
// JWK is the default provisioner, an entity that can sign tokens necessary for
@ -129,11 +135,20 @@ func (p *JWK) AuthorizeRevoke(token string) error {
}
// AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) {
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil {
return nil, err
}
// Check for SSH token
if claims.Step != nil && claims.Step.SSH != nil {
if p.claimer.IsSSHCAEnabled() == false {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
}
return p.authorizeSSHSign(claims)
}
// NOTE: This is for backwards compatibility with older versions of cli
// and certificates. Older versions added the token subject as the only SAN
// in a CSR by default.
@ -161,3 +176,41 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
}
return nil
}
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) {
t := now()
opts := claims.Step.SSH
signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts),
// set the key id to the token subject
sshCertificateKeyIDModifier(claims.Subject),
}
// Add modifiers from custom claims
if opts.CertType != "" {
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
}
if len(opts.Principals) > 0 {
signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals))
}
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
}
if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
}
// Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
return append(signOptions,
// set the default extensions
&sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer},
// require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
), nil
}

@ -1,6 +1,8 @@
package provisioner
import (
"context"
"crypto"
"crypto/x509"
"errors"
"strings"
@ -13,11 +15,19 @@ import (
var (
defaultDisableRenewal = false
defaultEnableSSHCA = true
globalProvisionerClaims = Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA,
}
)
@ -259,7 +269,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.AuthorizeSign(tt.args.token); err != nil {
ctx := NewContextWithMethod(context.Background(), SignMethod)
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
@ -318,3 +329,201 @@ func TestJWK_AuthorizeRenewal(t *testing.T) {
})
}
}
func TestJWK_AuthorizeSign_SSH(t *testing.T) {
tm, fn := mockNow()
defer fn()
p1, err := generateJWK()
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
iss, aud := p1.Name, testAudiences.Sign[0]
t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
assert.FatalError(t, err)
t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
key, err := generateJSONWebKey()
assert.FatalError(t, err)
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SSHOptions{
CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
}
expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
type args struct {
token string
sshOpts SSHOptions
}
tests := []struct {
name string
prov *JWK
args args
expected *SSHOptions
wantErr bool
wantSignErr bool
}{
{"user", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
{"host", p1, args{t2, SSHOptions{}}, expectedHostOptions, false, false},
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
{"fail-signature", p1, args{failSig, SSHOptions{}}, nil, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {
assert.Nil(t, cert)
} else {
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
}
}
}
})
}
}
func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
tm, fn := mockNow()
defer fn()
p1, err := generateJWK()
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.Sign[0], time.Now()
key, err := generateJSONWebKey()
assert.FatalError(t, err)
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SSHOptions{
CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
}
expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
type args struct {
sub, iss, aud string
iat time.Time
tokSSHOpts *SSHOptions
userSSHOpts *SSHOptions
jwk *jose.JSONWebKey
}
tests := []struct {
name string
prov *JWK
args args
expected *SSHOptions
wantErr bool
wantSignErr bool
}{
{"ok-user", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, expectedUserOptions, false, false},
{"ok-host", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, &SSHOptions{}, jwk}, expectedHostOptions, false, false},
{"ok-user-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "user", Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false},
{"ok-host-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, jwk}, expectedHostOptions, false, false},
{"ok-user-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user"}, &SSHOptions{Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false},
{"ok-host-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{Principals: []string{"smallstep.com"}}, &SSHOptions{CertType: "host"}, jwk}, expectedHostOptions, false, false},
{"ok-user-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{
CertType: "user", Principals: []string{"name"},
}, &SSHOptions{
ValidAfter: NewTimeDuration(tm.Add(-time.Hour)),
}, jwk}, &SSHOptions{
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), ValidBefore: NewTimeDuration(tm.Add(userDuration - time.Hour)),
}, false, false},
{"ok-user-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{
CertType: "user", Principals: []string{"name"},
}, &SSHOptions{
ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
}, jwk}, &SSHOptions{
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
}, false, false},
{"ok-user-validAfter-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{
CertType: "user", Principals: []string{"name"},
}, &SSHOptions{
ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
}, jwk}, &SSHOptions{
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
}, false, false},
{"ok-user-match", p1, args{sub, iss, aud, iat, &SSHOptions{
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)),
}, &SSHOptions{
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)),
}, jwk}, &SSHOptions{
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
}, false, false},
{"fail-certType", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{CertType: "host"}, jwk}, nil, false, true},
{"fail-principals", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{Principals: []string{"root"}}, jwk}, nil, false, true},
{"fail-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm)}, &SSHOptions{ValidAfter: NewTimeDuration(tm.Add(time.Hour))}, jwk}, nil, false, true},
{"fail-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidBefore: NewTimeDuration(tm.Add(time.Hour))}, &SSHOptions{ValidBefore: NewTimeDuration(tm.Add(10 * time.Hour))}, jwk}, nil, false, true},
{"fail-subject", p1, args{"", iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
{"fail-issuer", p1, args{sub, "invalid", aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
{"fail-audience", p1, args{sub, iss, "invalid", iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
{"fail-expired", p1, args{sub, iss, aud, iat.Add(-6 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
{"fail-notBefore", p1, args{sub, iss, aud, iat.Add(5 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
assert.FatalError(t, err)
if got, err := tt.prov.AuthorizeSign(ctx, token); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
} else if !tt.wantErr && assert.NotNil(t, got) {
var opts SSHOptions
if tt.args.userSSHOpts != nil {
opts = *tt.args.userSSHOpts
}
cert, err := signSSHCertificate(key.Public().Key, opts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {
assert.Nil(t, cert)
} else {
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
}
}
}
})
}
}

@ -0,0 +1,34 @@
package provisioner
import (
"context"
)
// Method indicates the action to action that we will perform, it's used as part
// of the context in the call to authorize. It defaults to Sing.
type Method int
// The key to save the Method in the context.
type methodKey struct{}
const (
// SignMethod is the method used to sign X.509 certificates.
SignMethod Method = iota
// SignSSHMethod is the method used to sign SSH certificate.
SignSSHMethod
// RevokeMethod is the method used to revoke X.509 certificates.
RevokeMethod
)
// NewContextWithMethod creates a new context from ctx and attaches method to
// it.
func NewContextWithMethod(ctx context.Context, method Method) context.Context {
return context.WithValue(ctx, methodKey{}, method)
}
// MethodFromContext returns the Method saved in ctx. Returns Sign if the given
// context has no Method associated with it.
func MethodFromContext(ctx context.Context) Method {
m, _ := ctx.Value(methodKey{}).(Method)
return m
}

@ -1,6 +1,9 @@
package provisioner
import "crypto/x509"
import (
"context"
"crypto/x509"
)
// noop provisioners is a provisioner that accepts anything.
type noop struct{}
@ -28,7 +31,7 @@ func (p *noop) Init(config Config) error {
return nil
}
func (p *noop) AuthorizeSign(token string) ([]SignOption, error) {
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil
}

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/x509"
"testing"
@ -21,7 +22,8 @@ func Test_noop(t *testing.T) {
assert.Equals(t, "", key)
assert.Equals(t, false, ok)
sigOptions, err := p.AuthorizeSign("foo")
ctx := NewContextWithMethod(context.Background(), SignMethod)
sigOptions, err := p.AuthorizeSign(ctx, "foo")
assert.Equals(t, []SignOption{}, sigOptions)
assert.Equals(t, nil, err)
}

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/x509"
"encoding/json"
"net/http"
@ -259,12 +260,29 @@ func (o *OIDC) AuthorizeRevoke(token string) error {
}
// AuthorizeSign validates the given token.
func (o *OIDC) AuthorizeSign(token string) ([]SignOption, error) {
func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := o.authorizeToken(token)
if err != nil {
return nil, err
}
// Check for the sign ssh method, default to sign X.509
if m := MethodFromContext(ctx); m == SignSSHMethod {
if o.claimer.IsSSHCAEnabled() == false {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID())
}
return o.authorizeSSHSign(claims)
}
// Admins should be able to authorize any SAN
if o.IsAdmin(claims.Email) {
return []SignOption{
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
}, nil
}
so := []SignOption{
defaultPublicKeyValidator{},
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
@ -287,6 +305,42 @@ func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
return nil
}
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) authorizeSSHSign(claims *openIDPayload) ([]SignOption, error) {
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(claims.Email),
}
name := SanitizeSSHUserPrincipal(claims.Email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email address '%s'", name, claims.Email)
}
// Admin users will default to user + name but they can be changed by the
// user options. Non-admins are only able to sign user certificates.
defaults := SSHOptions{
CertType: SSHUserCert,
Principals: []string{name},
}
if !o.IsAdmin(claims.Email) {
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
}
// Default to a user with name as principal if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
return append(signOptions,
// set the default extensions
&sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{o.claimer},
// require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
), nil
}
func getAndDecode(uri string, v interface{}) error {
resp, err := http.Get(uri)
if err != nil {

@ -1,6 +1,8 @@
package provisioner
import (
"context"
"crypto"
"crypto/x509"
"fmt"
"strings"
@ -276,7 +278,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.AuthorizeSign(tt.args.token)
ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
@ -286,7 +289,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
} else {
assert.NotNil(t, got)
if tt.name == "admin" {
assert.Len(t, 4, got)
assert.Len(t, 3, got)
} else {
assert.Len(t, 5, got)
}
@ -295,6 +298,117 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
}
}
func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
tm, fn := mockNow()
defer fn()
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p2.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
key, err := generateJSONWebKey()
assert.FatalError(t, err)
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SSHOptions{
CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
}
expectedAdminOptions := &SSHOptions{
CertType: "user", Principals: []string{"root"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
}
expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
}
type args struct {
token string
sshOpts SSHOptions
}
tests := []struct {
name string
prov *OIDC
args args
expected *SSHOptions
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
{"admin", p3, args{okAdmin, SSHOptions{}}, expectedAdminOptions, false, false},
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}}, expectedAdminOptions, false, false},
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}}, expectedAdminOptions, false, false},
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}}, nil, false, true},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}}, nil, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}}, nil, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {
assert.Nil(t, cert)
} else {
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
}
}
}
})
}
}
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()

@ -1,9 +1,11 @@
package provisioner
import (
"context"
"crypto/x509"
"encoding/json"
"net/url"
"regexp"
"strings"
"github.com/pkg/errors"
@ -17,7 +19,7 @@ type Interface interface {
GetType() Type
GetEncryptedKey() (kid string, key string, ok bool)
Init(config Config) error
AuthorizeSign(token string) ([]SignOption, error)
AuthorizeSign(ctx context.Context, token string) ([]SignOption, error)
AuthorizeRenewal(cert *x509.Certificate) error
AuthorizeRevoke(token string) error
}
@ -169,3 +171,29 @@ func (l *List) UnmarshalJSON(data []byte) error {
return nil
}
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}

@ -1,6 +1,8 @@
package provisioner
import "testing"
import (
"testing"
)
func TestType_String(t *testing.T) {
tests := []struct {
@ -24,3 +26,29 @@ func TestType_String(t *testing.T) {
})
}
}
func TestSanitizeSSHUserPrincipal(t *testing.T) {
type args struct {
email string
}
tests := []struct {
name string
args args
want string
}{
{"simple", args{"foobar"}, "foobar"},
{"camelcase", args{"FooBar"}, "foobar"},
{"email", args{"foo@example.com"}, "foo"},
{"email with dots", args{"foo.bar.zar@example.com"}, "foobarzar"},
{"email with dashes", args{"foo-bar-zar@example.com"}, "foo-bar-zar"},
{"email with underscores", args{"foo_bar_zar@example.com"}, "foo_bar_zar"},
{"email with symbols", args{"Foo.Bar0123456789!#$%&'*+-/=?^_`{|}~;@example.com"}, "foobar0123456789________-___________"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := SanitizeSSHUserPrincipal(tt.args.email); got != tt.want {
t.Errorf("SanitizeSSHUserPrincipal() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,306 @@
package provisioner
import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
)
const (
// SSHUserCert is the string used to represent ssh.UserCert.
SSHUserCert = "user"
// SSHHostCert is the string used to represent ssh.HostCert.
SSHHostCert = "host"
)
// SSHCertificateModifier is the interface used to change properties in an SSH
// certificate.
type SSHCertificateModifier interface {
SignOption
Modify(cert *ssh.Certificate) error
}
// SSHCertificateOptionModifier is the interface used to add custom options used
// to modify the SSH certificate.
type SSHCertificateOptionModifier interface {
SignOption
Option(o SSHOptions) SSHCertificateModifier
}
// SSHCertificateValidator is the interface used to validate an SSH certificate.
type SSHCertificateValidator interface {
SignOption
Valid(cert *ssh.Certificate) error
}
// SSHCertificateOptionsValidator is the interface used to validate the custom
// options used to modify the SSH certificate.
type SSHCertificateOptionsValidator interface {
SignOption
Valid(got SSHOptions) error
}
// SSHOptions contains the options that can be passed to the SignSSH method.
type SSHOptions struct {
CertType string `json:"certType"`
Principals []string `json:"principals"`
ValidAfter TimeDuration `json:"validAfter,omitempty"`
ValidBefore TimeDuration `json:"validBefore,omitempty"`
}
// Type returns the uint32 representation of the CertType.
func (o SSHOptions) Type() uint32 {
return sshCertTypeUInt32(o.CertType)
}
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
switch o.CertType {
case "": // ignore
case SSHUserCert:
cert.CertType = ssh.UserCert
case SSHHostCert:
cert.CertType = ssh.HostCert
default:
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
}
cert.ValidPrincipals = o.Principals
if !o.ValidAfter.IsZero() {
cert.ValidAfter = uint64(o.ValidAfter.Time().Unix())
}
if !o.ValidBefore.IsZero() {
cert.ValidBefore = uint64(o.ValidBefore.Time().Unix())
}
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errors.New("ssh certificate valid after cannot be greater than valid before")
}
return nil
}
// match compares two SSHOptions and return an error if they don't match. It
// ignores zero values.
func (o SSHOptions) match(got SSHOptions) error {
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
}
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
}
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
}
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
}
return nil
}
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
// Key ID in the SSH certificate.
type sshCertificateKeyIDModifier string
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
cert.KeyId = string(m)
return nil
}
// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the
// certificate type to the SSH certificate.
type sshCertificateCertTypeModifier string
func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error {
cert.CertType = sshCertTypeUInt32(string(m))
return nil
}
// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the
// principals to the SSH certificate.
type sshCertificatePrincipalsModifier []string
func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error {
cert.ValidPrincipals = []string(m)
return nil
}
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
// ValidAfter in the SSH certificate.
type sshCertificateValidAfterModifier uint64
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
cert.ValidAfter = uint64(m)
return nil
}
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
// ValidBefore in the SSH certificate.
type sshCertificateValidBeforeModifier uint64
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
cert.ValidBefore = uint64(m)
return nil
}
// sshCertificateDefaultModifier implements a SSHCertificateModifier that
// modifies the certificate with the given options if they are not set.
type sshCertificateDefaultsModifier SSHOptions
// Modify implements the SSHCertificateModifier interface.
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
if cert.CertType == 0 {
cert.CertType = sshCertTypeUInt32(m.CertType)
}
if len(cert.ValidPrincipals) == 0 {
cert.ValidPrincipals = m.Principals
}
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
cert.ValidAfter = uint64(m.ValidAfter.Unix())
}
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
cert.ValidBefore = uint64(m.ValidBefore.Unix())
}
return nil
}
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
// the default extensions in an SSH certificate.
type sshDefaultExtensionModifier struct{}
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
switch cert.CertType {
// Default to no extensions for HostCert.
case ssh.HostCert:
return nil
case ssh.UserCert:
if cert.Extensions == nil {
cert.Extensions = make(map[string]string)
}
cert.Extensions["permit-X11-forwarding"] = ""
cert.Extensions["permit-agent-forwarding"] = ""
cert.Extensions["permit-port-forwarding"] = ""
cert.Extensions["permit-pty"] = ""
cert.Extensions["permit-user-rc"] = ""
return nil
default:
return errors.New("ssh certificate type has not been set or is invalid")
}
}
// sshCertificateValidityModifier is a SSHCertificateModifier checks the
// validity bounds, setting them if they are not provided. It will fail if a
// CertType has not been set or is not valid.
type sshCertificateValidityModifier struct {
*Claimer
}
func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
var d, min, max time.Duration
switch cert.CertType {
case ssh.UserCert:
d = m.DefaultUserSSHCertDuration()
min = m.MinUserSSHCertDuration()
max = m.MaxUserSSHCertDuration()
case ssh.HostCert:
d = m.DefaultHostSSHCertDuration()
min = m.MinHostSSHCertDuration()
max = m.MaxHostSSHCertDuration()
case 0:
return errors.New("ssh certificate type has not been set")
default:
return errors.Errorf("unknown ssh certificate type %d", cert.CertType)
}
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now().Unix())
}
if cert.ValidBefore == 0 {
t := time.Unix(int64(cert.ValidAfter), 0)
cert.ValidBefore = uint64(t.Add(d).Unix())
}
diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
switch {
case diff < min:
return errors.Errorf("ssh certificate duration cannot be lower than %s", min)
case diff > max:
return errors.Errorf("ssh certificate duration cannot be greater than %s", max)
default:
return nil
}
}
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
// usually present in the token.
type sshCertificateOptionsValidator SSHOptions
// Valid implements SSHCertificateOptionsValidator and returns nil if both
// SSHOptions match.
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
want := SSHOptions(v)
return want.match(got)
}
// sshCertificateDefaultValidator implements a simple validator for all the
// fields in the SSH certificate.
type sshCertificateDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the necessary fields.
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
switch {
case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty")
case cert.Key == nil:
return errors.New("ssh certificate key cannot be nil")
case cert.Serial == 0:
return errors.New("ssh certificate serial cannot be 0")
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType)
case cert.KeyId == "":
return errors.New("ssh certificate key id cannot be empty")
case len(cert.ValidPrincipals) == 0:
return errors.New("ssh certificate valid principals cannot be empty")
case cert.ValidAfter == 0:
return errors.New("ssh certificate valid after cannot be 0")
case cert.ValidBefore == 0:
return errors.New("ssh certificate valid before cannot be 0")
case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0:
return errors.New("ssh certificate extensions cannot be empty")
case cert.SignatureKey == nil:
return errors.New("ssh certificate signature key cannot be nil")
case cert.Signature == nil:
return errors.New("ssh certificate signature cannot be nil")
default:
return nil
}
}
// sshCertTypeUInt32
func sshCertTypeUInt32(ct string) uint32 {
switch ct {
case SSHUserCert:
return ssh.UserCert
case SSHHostCert:
return ssh.HostCert
default:
return 0
}
}
// containsAllMembers reports whether all members of subgroup are within group.
func containsAllMembers(group, subgroup []string) bool {
lg, lsg := len(group), len(subgroup)
if lsg > lg || (lg > 0 && lsg == 0) {
return false
}
visit := make(map[string]struct{}, lg)
for i := 0; i < lg; i++ {
visit[group[i]] = struct{}{}
}
for i := 0; i < lsg; i++ {
if _, ok := visit[subgroup[i]]; !ok {
return false
}
}
return true
}

@ -0,0 +1,125 @@
package provisioner
import (
"crypto"
"crypto/rand"
"fmt"
"reflect"
"time"
"golang.org/x/crypto/ssh"
)
func validateSSHCertificate(cert *ssh.Certificate, opts *SSHOptions) error {
switch {
case cert == nil:
return fmt.Errorf("certificate is nil")
case cert.Signature == nil:
return fmt.Errorf("certificate signature is nil")
case cert.SignatureKey == nil:
return fmt.Errorf("certificate signature is nil")
case !reflect.DeepEqual(cert.ValidPrincipals, opts.Principals):
return fmt.Errorf("certificate principals are not equal, want %v, got %v", opts.Principals, cert.ValidPrincipals)
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
return fmt.Errorf("certificate type %v is not valid", cert.CertType)
case opts.CertType == "user" && cert.CertType != ssh.UserCert:
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.UserCert, cert.CertType)
case opts.CertType == "host" && cert.CertType != ssh.HostCert:
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType)
case cert.ValidAfter != uint64(opts.ValidAfter.Unix()):
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
case cert.ValidBefore != uint64(opts.ValidBefore.Unix()):
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
case opts.CertType == "user" && len(cert.Extensions) != 5:
return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions))
case opts.CertType == "host" && len(cert.Extensions) != 0:
return fmt.Errorf("certificate extensions number is invalid, want 0, got %d", len(cert.Extensions))
default:
return nil
}
}
func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOption, signKey crypto.Signer) (*ssh.Certificate, error) {
pub, err := ssh.NewPublicKey(key)
if err != nil {
return nil, err
}
var mods []SSHCertificateModifier
var validators []SSHCertificateValidator
for _, op := range signOpts {
switch o := op.(type) {
// modify the ssh.Certificate
case SSHCertificateModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case SSHCertificateOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case SSHCertificateValidator:
validators = append(validators, o)
// validate the given SSHOptions
case SSHCertificateOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("signSSH: invalid extra option type %T", o)
}
}
// Build base certificate with the key and some random values
cert := &ssh.Certificate{
Nonce: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0},
Key: pub,
Serial: 1234567890,
}
// Use opts to modify the certificate
if err := opts.Modify(cert); err != nil {
return nil, err
}
// Use provisioner modifiers
for _, m := range mods {
if err := m.Modify(cert); err != nil {
return nil, err
}
}
// Get signer from authority keys
var signer ssh.Signer
switch cert.CertType {
case ssh.UserCert:
signer, err = ssh.NewSignerFromSigner(signKey)
case ssh.HostCert:
signer, err = ssh.NewSignerFromSigner(signKey)
default:
return nil, fmt.Errorf("unexpected ssh certificate type: %d", cert.CertType)
}
if err != nil {
return nil, err
}
cert.SignatureKey = signer.PublicKey()
// Get bytes for signing trailing the signature length.
data := cert.Marshal()
data = data[:len(data)-4]
// Sign the certificate
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, err
}
cert.Signature = sig
// User provisioners validators
for _, v := range validators {
if err := v.Valid(cert); err != nil {
return nil, err
}
}
return cert, nil
}

@ -57,6 +57,17 @@ func (t *TimeDuration) SetTime(tt time.Time) {
t.t, t.d = tt, 0
}
// IsZero returns true the TimeDuration represents the zero value, false
// otherwise.
func (t *TimeDuration) IsZero() bool {
return t.t.IsZero() && t.d == 0
}
// Equal returns if t and other are equal.
func (t *TimeDuration) Equal(other *TimeDuration) bool {
return t.t.Equal(other.t) && t.d == other.d
}
// MarshalJSON implements the json.Marshaler interface. If the time is set it
// will return the time in RFC 3339 format if not it will return the duration
// string.
@ -64,7 +75,7 @@ func (t TimeDuration) MarshalJSON() ([]byte, error) {
switch {
case t.t.IsZero():
if t.d == 0 {
return []byte("null"), nil
return []byte(`""`), nil
}
return json.Marshal(t.d.String())
default:
@ -102,11 +113,16 @@ func (t *TimeDuration) UnmarshalJSON(data []byte) error {
return errors.Errorf("failed to parse %s", data)
}
// Time calculates the embedded time.Time, sets it if necessary, and returns it.
// Time calculates the time if needed and returns it.
func (t *TimeDuration) Time() time.Time {
return t.RelativeTime(now())
}
// Unix calculates the time if needed it and returns the Unix time in seconds.
func (t *TimeDuration) Unix() int64 {
return t.RelativeTime(now()).Unix()
}
// RelativeTime returns the embedded time.Time or the base time plus the
// duration if this is not zero.
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {

@ -6,6 +6,17 @@ import (
"time"
)
func mockNow() (time.Time, func()) {
tm := time.Unix(1584198566, 535897000).UTC()
nowFn := now
now = func() time.Time {
return tm
}
return tm, func() {
now = nowFn
}
}
func TestNewTimeDuration(t *testing.T) {
tm := time.Unix(1584198566, 535897000).UTC()
type args struct {
@ -137,7 +148,7 @@ func TestTimeDuration_MarshalJSON(t *testing.T) {
want []byte
wantErr bool
}{
{"null", TimeDuration{}, []byte("null"), false},
{"empty", TimeDuration{}, []byte(`""`), false},
{"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false},
{"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false},
{"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true},
@ -166,7 +177,7 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) {
want *TimeDuration
wantErr bool
}{
{"null", args{[]byte("null")}, &TimeDuration{}, false},
{"empty", args{[]byte(`""`)}, &TimeDuration{}, false},
{"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false},
{"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false},
{"fail", args{[]byte("123")}, &TimeDuration{}, true},
@ -186,15 +197,8 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) {
}
func TestTimeDuration_Time(t *testing.T) {
nowFn := now
defer func() {
now = nowFn
now()
}()
tm := time.Unix(1584198566, 535897000).UTC()
now = func() time.Time {
return tm
}
tm, fn := mockNow()
defer fn()
tests := []struct {
name string
timeDuration *TimeDuration
@ -211,6 +215,30 @@ func TestTimeDuration_Time(t *testing.T) {
got := tt.timeDuration.Time()
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want)
}
})
}
}
func TestTimeDuration_Unix(t *testing.T) {
tm, fn := mockNow()
defer fn()
tests := []struct {
name string
timeDuration *TimeDuration
want int64
}{
{"zero", nil, -62135596800},
{"zero", &TimeDuration{}, -62135596800},
{"timestamp", &TimeDuration{t: tm}, 1584198566},
{"local", &TimeDuration{t: tm.Local()}, 1584198566},
{"duration", &TimeDuration{d: 1 * time.Hour}, 1584202166},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.timeDuration.Unix()
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("TimeDuration.Unix() = %v, want %v", got, tt.want)
}
})
@ -218,15 +246,8 @@ func TestTimeDuration_Time(t *testing.T) {
}
func TestTimeDuration_String(t *testing.T) {
nowFn := now
defer func() {
now = nowFn
now()
}()
tm := time.Unix(1584198566, 535897000).UTC()
now = func() time.Time {
return tm
}
tm, fn := mockNow()
defer fn()
type fields struct {
t time.Time
d time.Duration

@ -480,6 +480,54 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{
CertType: "user",
Principals: []string{"name"},
}, jwk)
}
func generateSimpleSSHHostToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{
CertType: "host",
Principals: []string{"smallstep.com"},
}, jwk)
}
func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *SSHOptions, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
if err != nil {
return "", err
}
id, err := randutil.ASCII(64)
if err != nil {
return "", err
}
claims := struct {
jose.Claims
Step *stepPayload `json:"step,omitempty"`
}{
Claims: jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
},
Step: &stepPayload{
SSH: sshOpts,
},
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},

@ -13,7 +13,7 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) {
key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
http.StatusNotFound, context{}}
http.StatusNotFound, apiCtx{}}
}
return key, nil
}
@ -31,7 +31,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, context{}}
http.StatusNotFound, apiCtx{}}
}
return p, nil
}

@ -35,7 +35,7 @@ func TestGetEncryptedKey(t *testing.T) {
a: a,
kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
http.StatusNotFound, context{}},
http.StatusNotFound, apiCtx{}},
}
},
}

@ -12,13 +12,13 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
val, ok := a.certificates.Load(sum)
if !ok {
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
http.StatusNotFound, context{}}
http.StatusNotFound, apiCtx{}}
}
crt, ok := val.(*x509.Certificate)
if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, context{}}
http.StatusInternalServerError, apiCtx{}}
}
return crt, nil
}
@ -53,7 +53,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
if !ok {
federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, context{}}
http.StatusInternalServerError, apiCtx{}}
return false
}
federation = append(federation, crt)

@ -19,8 +19,8 @@ func TestRoot(t *testing.T) {
sum string
err *apiError
}{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
}

@ -0,0 +1,239 @@
package authority
import (
"crypto/rand"
"encoding/binary"
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/randutil"
"golang.org/x/crypto/ssh"
)
const (
// SSHAddUserPrincipal is the principal that will run the add user command.
// Defaults to "provisioner" but it can be changed in the configuration.
SSHAddUserPrincipal = "provisioner"
// SSHAddUserCommand is the default command to run to add a new user.
// Defaults to "sudo useradd -m <principal>; nc -q0 localhost 22" but it can be changed in the
// configuration. The string "<principal>" will be replace by the new
// principal to add.
SSHAddUserCommand = "sudo useradd -m <principal>; nc -q0 localhost 22"
)
// SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var mods []provisioner.SSHCertificateModifier
var validators []provisioner.SSHCertificateValidator
for _, op := range signOpts {
switch o := op.(type) {
// modify the ssh.Certificate
case provisioner.SSHCertificateModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertificateOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case provisioner.SSHCertificateValidator:
validators = append(validators, o)
// validate the given SSHOptions
case provisioner.SSHCertificateOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
}
default:
return nil, &apiError{
err: errors.Errorf("signSSH: invalid extra option type %T", o),
code: http.StatusInternalServerError,
}
}
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError}
}
var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error reading random number"),
code: http.StatusInternalServerError,
}
}
// Build base certificate with the key and some random values
cert := &ssh.Certificate{
Nonce: []byte(nonce),
Key: key,
Serial: serial,
}
// Use opts to modify the certificate
if err := opts.Modify(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
}
// Use provisioner modifiers
for _, m := range mods {
if err := m.Modify(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
}
}
// Get signer from authority keys
var signer ssh.Signer
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, &apiError{
err: errors.New("signSSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
}
if signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error creating signer"),
code: http.StatusInternalServerError,
}
}
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, &apiError{
err: errors.New("signSSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
}
if signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error creating signer"),
code: http.StatusInternalServerError,
}
}
default:
return nil, &apiError{
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
}
cert.SignatureKey = signer.PublicKey()
// Get bytes for signing trailing the signature length.
data := cert.Marshal()
data = data[:len(data)-4]
// Sign the certificate
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error signing certificate"),
code: http.StatusInternalServerError,
}
}
cert.Signature = sig
// User provisioners validators
for _, v := range validators {
if err := v.Valid(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
}
}
return cert, nil
}
// SignSSHAddUser signs a certificate that provisions a new user in a server.
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
if a.sshCAUserCertSignKey == nil {
return nil, &apiError{
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
}
if subject.CertType != ssh.UserCert {
return nil, &apiError{
err: errors.New("signSSHProxy: certificate is not a user certificate"),
code: http.StatusForbidden,
}
}
if len(subject.ValidPrincipals) != 1 {
return nil, &apiError{
err: errors.New("signSSHProxy: certificate does not have only one principal"),
code: http.StatusForbidden,
}
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError}
}
var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSHProxy: error reading random number"),
code: http.StatusInternalServerError,
}
}
signer, err := ssh.NewSignerFromSigner(a.sshCAUserCertSignKey)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSHProxy: error creating signer"),
code: http.StatusInternalServerError,
}
}
principal := subject.ValidPrincipals[0]
addUserPrincipal := a.getAddUserPrincipal()
cert := &ssh.Certificate{
Nonce: []byte(nonce),
Key: key,
Serial: serial,
CertType: ssh.UserCert,
KeyId: principal + "-" + addUserPrincipal,
ValidPrincipals: []string{addUserPrincipal},
ValidAfter: subject.ValidAfter,
ValidBefore: subject.ValidBefore,
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{
"force-command": a.getAddUserCommand(principal),
},
},
SignatureKey: signer.PublicKey(),
}
// Get bytes for signing trailing the signature length.
data := cert.Marshal()
data = data[:len(data)-4]
// Sign the certificate
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, err
}
cert.Signature = sig
return cert, nil
}
func (a *Authority) getAddUserPrincipal() (cmd string) {
if a.config.SSH.AddUserPrincipal == "" {
return SSHAddUserPrincipal
}
return a.config.SSH.AddUserPrincipal
}
func (a *Authority) getAddUserCommand(principal string) string {
var cmd string
if a.config.SSH.AddUserCommand == "" {
cmd = SSHAddUserCommand
} else {
cmd = a.config.SSH.AddUserCommand
}
return strings.Replace(cmd, "<principal>", principal, -1)
}

@ -0,0 +1,252 @@
package authority
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"golang.org/x/crypto/ssh"
)
type sshTestModifier ssh.Certificate
func (m sshTestModifier) Modify(cert *ssh.Certificate) error {
if m.CertType != 0 {
cert.CertType = m.CertType
}
if m.KeyId != "" {
cert.KeyId = m.KeyId
}
if m.ValidAfter != 0 {
cert.ValidAfter = m.ValidAfter
}
if m.ValidBefore != 0 {
cert.ValidBefore = m.ValidBefore
}
if len(m.ValidPrincipals) != 0 {
cert.ValidPrincipals = m.ValidPrincipals
}
if m.Permissions.CriticalOptions != nil {
cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions
}
if m.Permissions.Extensions != nil {
cert.Permissions.Extensions = m.Permissions.Extensions
}
return nil
}
type sshTestCertModifier string
func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
if m == "" {
return nil
}
return fmt.Errorf(string(m))
}
type sshTestCertValidator string
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error {
if v == "" {
return nil
}
return fmt.Errorf(string(v))
}
type sshTestOptionsValidator string
func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
if v == "" {
return nil
}
return fmt.Errorf(string(v))
}
type sshTestOptionsModifier string
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
return sshTestCertModifier(string(m))
}
func TestAuthority_SignSSH(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
pub, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
userOptions := sshTestModifier{
CertType: ssh.UserCert,
}
hostOptions := sshTestModifier{
CertType: ssh.HostCert,
}
now := time.Now()
type fields struct {
sshCAUserCertSignKey crypto.Signer
sshCAHostCertSignKey crypto.Signer
}
type args struct {
key ssh.PublicKey
opts provisioner.SSHOptions
signOpts []provisioner.SignOption
}
type want struct {
CertType uint32
Principals []string
ValidAfter uint64
ValidBefore uint64
}
tests := []struct {
name string
fields fields
args args
want want
wantErr bool
}{
{"ok-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-type-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-type-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
{"ok-opts-valid-after", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false},
{"ok-opts-valid-before", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false},
{"ok-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false},
{"ok-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false},
{"fail-opts-type", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "foo"}, []provisioner.SignOption{}}, want{}, true},
{"fail-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("an error")}}, want{}, true},
{"fail-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("an error")}}, want{}, true},
{"fail-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("an error")}}, want{}, true},
{"fail-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("an error")}}, want{}, true},
{"fail-bad-sign-options", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, "wrong type"}}, want{}, true},
{"fail-no-user-key", fields{nil, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{}, true},
{"fail-no-host-key", fields{signKey, nil}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{}, true},
{"fail-bad-type", fields{signKey, nil}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{sshTestModifier{CertType: 0}}}, want{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && assert.NotNil(t, got) {
assert.Equals(t, tt.want.CertType, got.CertType)
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
assert.NotNil(t, got.Key)
assert.NotNil(t, got.Nonce)
assert.NotEquals(t, 0, got.Serial)
assert.NotNil(t, got.Signature)
assert.NotNil(t, got.SignatureKey)
}
})
}
}
func TestAuthority_SignSSHAddUser(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
pub, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
type fields struct {
sshCAUserCertSignKey crypto.Signer
sshCAHostCertSignKey crypto.Signer
addUserPrincipal string
addUserCommand string
}
type args struct {
key ssh.PublicKey
subject *ssh.Certificate
}
type want struct {
CertType uint32
Principals []string
ValidAfter uint64
ValidBefore uint64
ForceCommand string
}
now := time.Now()
validCert := &ssh.Certificate{
CertType: ssh.UserCert,
ValidPrincipals: []string{"user"},
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}
validWant := want{
CertType: ssh.UserCert,
Principals: []string{"provisioner"},
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
ForceCommand: "sudo useradd -m user; nc -q0 localhost 22",
}
tests := []struct {
name string
fields fields
args args
want want
wantErr bool
}{
{"ok", fields{signKey, signKey, "", ""}, args{pub, validCert}, validWant, false},
{"ok-no-host-key", fields{signKey, nil, "", ""}, args{pub, validCert}, validWant, false},
{"ok-custom-principal", fields{signKey, signKey, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false},
{"ok-custom-command", fields{signKey, signKey, "", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false},
{"ok-custom-principal-and-command", fields{signKey, signKey, "my-principal", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false},
{"fail-no-user-key", fields{nil, signKey, "", ""}, args{pub, validCert}, want{}, true},
{"fail-no-user-cert", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true},
{"fail-no-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true},
{"fail-many-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
a.config.SSH = &SSHConfig{
AddUserPrincipal: tt.fields.addUserPrincipal,
AddUserCommand: tt.fields.addUserCommand,
}
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && assert.NotNil(t, got) {
assert.Equals(t, tt.want.CertType, got.CertType)
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId)
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions)
assert.Equals(t, nil, got.Extensions)
assert.NotNil(t, got.Key)
assert.NotNil(t, got.Nonce)
assert.NotEquals(t, 0, got.Serial)
assert.NotNil(t, got.Signature)
assert.NotNil(t, got.SignatureKey)
}
})
}
}

@ -58,7 +58,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
// Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
var (
errContext = context{"csr": csr, "signOptions": signOpts}
errContext = apiCtx{"csr": csr, "signOptions": signOpts}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{}
issIdentity = a.intermediateIdentity
@ -181,23 +181,23 @@ func (a *Authority) Renew(oldCert *x509.Certificate) (*x509.Certificate, *x509.C
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
issIdentity.Crt, issIdentity.Key)
if err != nil {
return nil, nil, &apiError{err, http.StatusInternalServerError, context{}}
return nil, nil, &apiError{err, http.StatusInternalServerError, apiCtx{}}
}
crtBytes, err := leaf.CreateCertificate()
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
http.StatusInternalServerError, context{}}
http.StatusInternalServerError, apiCtx{}}
}
serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
http.StatusInternalServerError, context{}}
http.StatusInternalServerError, apiCtx{}}
}
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
http.StatusInternalServerError, context{}}
http.StatusInternalServerError, apiCtx{}}
}
return serverCert, caCert, nil
@ -222,7 +222,7 @@ type RevokeOptions struct {
//
// TODO: Add OCSP and CRL support.
func (a *Authority) Revoke(opts *RevokeOptions) error {
errContext := context{
errContext := apiCtx{
"serialNumber": opts.Serial,
"reasonCode": opts.ReasonCode,
"reason": opts.Reason,

@ -1,6 +1,7 @@
package authority
import (
"context"
"crypto/rand"
"crypto/sha1"
"crypto/x509"
@ -103,7 +104,8 @@ func TestSign(t *testing.T) {
assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err)
extraOpts, err := a.Authorize(token)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
extraOpts, err := a.Authorize(ctx, token)
assert.FatalError(t, err)
type signTest struct {
@ -124,7 +126,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid certificate request"),
http.StatusBadRequest,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -138,7 +140,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid extra option type string"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -153,7 +155,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -168,7 +170,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -185,7 +187,7 @@ func TestSign(t *testing.T) {
signOpts: _signOpts,
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
http.StatusUnauthorized,
context{"csr": csr, "signOptions": _signOpts},
apiCtx{"csr": csr, "signOptions": _signOpts},
},
}
},
@ -200,7 +202,7 @@ func TestSign(t *testing.T) {
signOpts: signOpts,
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
http.StatusUnauthorized,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -227,7 +229,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
signOpts: signOpts,
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
http.StatusUnauthorized,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -238,7 +240,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
storeCertificate: func(crt *x509.Certificate) error {
return &apiError{errors.New("force"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}}
apiCtx{"csr": csr, "signOptions": signOpts}}
},
}
return &signTest{
@ -248,7 +250,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
signOpts: signOpts,
err: &apiError{errors.New("sign: error storing certificate in db: force"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
apiCtx{"csr": csr, "signOptions": signOpts},
},
}
},
@ -401,7 +403,7 @@ func TestRenew(t *testing.T) {
auth: _a,
crt: crt,
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
http.StatusInternalServerError, context{}},
http.StatusInternalServerError, apiCtx{}},
}, nil
},
"fail-unauthorized": func() (*renewTest, error) {
@ -596,7 +598,7 @@ func TestRevoke(t *testing.T) {
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
now := time.Now().UTC()
getCtx := func() map[string]interface{} {
return context{
return apiCtx{
"serialNumber": "sn",
"reasonCode": reasonCode,
"reason": reason,

@ -373,6 +373,28 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
return &sign, nil
}
// SignSSH performs the SSH certificate sign request to the CA and returns the
// api.SignSSHResponse struct.
func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"})
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var sign api.SignSSHResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &sign, nil
}
// Renew performs the renew request to the CA and returns the api.SignResponse
// struct.
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {

@ -280,7 +280,11 @@ func stringifyFlag(f cli.Flag) string {
usage := fv.FieldByName("Usage").String()
placeholder := placeholderString.FindString(usage)
if placeholder == "" {
placeholder = "<value>"
switch f.(type) {
case cli.BoolFlag, cli.BoolTFlag:
default:
placeholder = "<value>"
}
}
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage
}

Loading…
Cancel
Save