Instrument getIdentity func for OIDC ssh provisioner

pull/166/head^2
max furman 5 years ago
parent 3d970b45c8
commit 414a94b210

@ -41,7 +41,7 @@ type Authority struct {
initOnce bool
// Custom functions
sshBastionFunc func(user, hostname string) (*Bastion, error)
getIdentityFunc func(p provisioner.Interface, email string) (*provisioner.Identity, error)
getIdentityFunc provisioner.GetIdentityFunc
}
// New creates and initiates a new Authority type.
@ -192,6 +192,7 @@ func (a *Authority) init() error {
UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys,
},
GetIdentityFunc: a.getIdentityFunc,
}
// Store all the provisioners
for _, p := range a.config.AuthorityConfig.Provisioners {

@ -197,10 +197,10 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Add modifiers from custom claims
// FIXME: this is also set in the sign method using SSHOptions.Modify.
if opts.CertType != "" {
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
signOptions = append(signOptions, sshCertTypeModifier(opts.CertType))
}
if len(opts.Principals) > 0 {
signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals))
signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals))
}
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))

@ -64,6 +64,7 @@ type OIDC struct {
configuration openIDConfiguration
keyStore *keyStore
claimer *Claimer
getIdentityFunc GetIdentityFunc
}
// IsAdmin returns true if the given email is in the Admins whitelist, false
@ -169,6 +170,13 @@ func (o *OIDC) Init(config Config) (err error) {
if err != nil {
return err
}
// Set the identity getter if it exists, otherwise use the default.
if config.GetIdentityFunc == nil {
o.getIdentityFunc = DefaultIdentityFunc
} else {
o.getIdentityFunc = config.GetIdentityFunc
}
return nil
}
@ -326,23 +334,26 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
sshCertificateKeyIDModifier(claims.Email),
}
name := SanitizeSSHUserPrincipal(claims.Email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email address '%s'", name, claims.Email)
// Get the identity using either the default identityFunc or one injected
// externally.
iden, err := o.getIdentityFunc(o, claims.Email)
if err != nil {
return nil, errors.Wrap(err, "authorizeSSHSign")
}
// Admin users will default to user + name but they can be changed by the
// user options. Non-admins are only able to sign user certificates.
defaults := SSHOptions{
CertType: SSHUserCert,
Principals: []string{name},
Principals: iden.Usernames,
}
// Admin users can use any principal, and can sign user and host certificates.
// Non-admin users can only use principals returned by the identityFunc, and
// can only sign user certificates.
if !o.IsAdmin(claims.Email) {
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
}
// Default to a user with name as principal if not set
// Default to a user certificate with usernames as principals if those options
// are not set.
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
return append(signOptions,

@ -347,6 +347,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
p4, err := generateOIDC()
assert.FatalError(t, err)
p5, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
@ -356,12 +360,27 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p4.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p5.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p2.Init(config))
assert.FatalError(t, p3.Init(config))
assert.FatalError(t, p4.Init(config))
assert.FatalError(t, p5.Init(config))
p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"max", "mariano"}}, nil
}
p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
return nil, errors.New("force")
}
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
okGetIdentityToken, err := generateSimpleToken("the-issuer", p4.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
failGetIdentityToken, err := generateSimpleToken("the-issuer", p5.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
@ -384,11 +403,11 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SSHOptions{
CertType: "user", Principals: []string{"name"},
CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
}
expectedAdminOptions := &SSHOptions{
CertType: "user", Principals: []string{"root"},
CertType: "user", Principals: []string{"root", "root@example.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
}
expectedHostOptions := &SSHOptions{
@ -412,17 +431,32 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub},
&SSHOptions{CertType: "user", Principals: []string{"max", "mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false},
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false},
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, expectedAdminOptions, false, false},
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"root"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false},
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

@ -185,6 +185,9 @@ type Config struct {
DB db.AuthDB
// SSHKeys are the root SSH public keys
SSHKeys *SSHKeys
// GetIdentityFunc is a function that returns an identity that will be
// used by the provisioner to populate certificate attributes.
GetIdentityFunc GetIdentityFunc
}
type provisioner struct {
@ -314,7 +317,7 @@ func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certif
}
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing SSH Certificates.
// this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRekey")
}
@ -325,6 +328,23 @@ type Identity struct {
Usernames []string `json:"usernames"`
}
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(p Interface, email string) (*Identity, error)
// DefaultIdentityFunc return a default identity depending on the provisioner type.
func DefaultIdentityFunc(p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
return &Identity{Usernames: []string{name, email}}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// MockProvisioner for testing
type MockProvisioner struct {
Mret1, Mret2, Mret3 interface{}
@ -335,9 +355,13 @@ type MockProvisioner struct {
MgetType func() Type
MgetEncryptedKey func() (string, string, bool)
Minit func(Config) error
MauthorizeRevoke func(ott string) error
MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error)
MauthorizeRenewal func(*x509.Certificate) error
MauthorizeRenew func(ctx context.Context, cert *x509.Certificate) error
MauthorizeRevoke func(ctx context.Context, ott string) error
MauthorizeSSHSign func(ctx context.Context, ott string) ([]SignOption, error)
MauthorizeSSHRenew func(ctx context.Context, ott string) (*ssh.Certificate, error)
MauthorizeSSHRekey func(ctx context.Context, ott string) (*ssh.Certificate, []SignOption, error)
MauthorizeSSHRevoke func(ctx context.Context, ott string) error
}
// GetID mock
@ -391,26 +415,58 @@ func (m *MockProvisioner) Init(c Config) error {
return m.Merr
}
// AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) {
if m.MauthorizeSign != nil {
return m.MauthorizeSign(ctx, ott)
}
return m.Mret1.([]SignOption), m.Merr
}
// AuthorizeRevoke mock
func (m *MockProvisioner) AuthorizeRevoke(ott string) error {
func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, ott string) error {
if m.MauthorizeRevoke != nil {
return m.MauthorizeRevoke(ott)
return m.MauthorizeRevoke(ctx, ott)
}
return m.Merr
}
// AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) {
// AuthorizeRenew mock
func (m *MockProvisioner) AuthorizeRenew(ctx context.Context, c *x509.Certificate) error {
if m.MauthorizeRenew != nil {
return m.MauthorizeRenew(ctx, c)
}
return m.Merr
}
// AuthorizeSSHSign mock
func (m *MockProvisioner) AuthorizeSSHSign(ctx context.Context, ott string) ([]SignOption, error) {
if m.MauthorizeSign != nil {
return m.MauthorizeSign(ctx, ott)
}
return m.Mret1.([]SignOption), m.Merr
}
// AuthorizeRenewal mock
func (m *MockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
if m.MauthorizeRenewal != nil {
return m.MauthorizeRenewal(c)
// AuthorizeSSHRenew mock
func (m *MockProvisioner) AuthorizeSSHRenew(ctx context.Context, ott string) (*ssh.Certificate, error) {
if m.MauthorizeRenew != nil {
return m.MauthorizeSSHRenew(ctx, ott)
}
return m.Mret1.(*ssh.Certificate), m.Merr
}
// AuthorizeSSHRekey mock
func (m *MockProvisioner) AuthorizeSSHRekey(ctx context.Context, ott string) (*ssh.Certificate, []SignOption, error) {
if m.MauthorizeSSHRekey != nil {
return m.MauthorizeSSHRekey(ctx, ott)
}
return m.Mret1.(*ssh.Certificate), m.Mret2.([]SignOption), m.Merr
}
// AuthorizeSSHRevoke mock
func (m *MockProvisioner) AuthorizeSSHRevoke(ctx context.Context, ott string) error {
if m.MauthorizeSSHRevoke != nil {
return m.MauthorizeSSHRevoke(ctx, ott)
}
return m.Merr
}

@ -2,6 +2,9 @@ package provisioner
import (
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
)
func TestType_String(t *testing.T) {
@ -52,3 +55,49 @@ func TestSanitizeSSHUserPrincipal(t *testing.T) {
})
}
}
func TestDefaultIdentityFunc(t *testing.T) {
type test struct {
p Interface
email string
err error
identity *Identity
}
tests := map[string]func(*testing.T) test{
"fail/unsupported-provisioner": func(t *testing.T) test {
return test{
p: &X5C{},
err: errors.New("provisioner type '*provisioner.X5C' not supported by identity function"),
}
},
"fail/bad-ssh-regex": func(t *testing.T) test {
return test{
p: &OIDC{},
email: "$%^#_>@smallstep.com",
err: errors.New("invalid principal '______' from email '$%^#_>@smallstep.com'"),
}
},
"ok": func(t *testing.T) test {
return test{
p: &OIDC{},
email: "max.furman@smallstep.com",
identity: &Identity{Usernames: []string{"maxfurman", "max.furman@smallstep.com"}},
}
},
}
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
identity, err := DefaultIdentityFunc(tc.p, tc.email)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, identity.Usernames, tc.identity.Usernames)
}
}
})
}
}

@ -107,6 +107,16 @@ func (o SSHOptions) match(got SSHOptions) error {
return nil
}
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the
// principals to the SSH certificate.
type sshCertPrincipalsModifier []string
// Modify the ValidPrincipals value of the cert.
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
cert.ValidPrincipals = []string(o)
return nil
}
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
// Key ID in the SSH certificate.
type sshCertificateKeyIDModifier string
@ -116,24 +126,16 @@ func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the
// certificate type to the SSH certificate.
type sshCertificateCertTypeModifier string
// sshCertTypeModifier is an SSHCertificateModifier that sets the
// certificate type.
type sshCertTypeModifier string
func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error {
// Modify sets the CertType for the ssh certificate.
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
cert.CertType = sshCertTypeUInt32(string(m))
return nil
}
// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the
// principals to the SSH certificate.
type sshCertificatePrincipalsModifier []string
func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error {
cert.ValidPrincipals = []string(m)
return nil
}
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
// ValidAfter in the SSH certificate.
type sshCertificateValidAfterModifier uint64

@ -237,10 +237,10 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Add modifiers from custom claims
// FIXME: this is also set in the sign method using SSHOptions.Modify.
if opts.CertType != "" {
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
signOptions = append(signOptions, sshCertTypeModifier(opts.CertType))
}
if len(opts.Principals) > 0 {
signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals))
signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals))
}
t := now()
if !opts.ValidAfter.IsZero() {

@ -636,9 +636,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
case sshCertificateKeyIDModifier:
assert.Equals(t, string(v), "foo")
case sshCertificateCertTypeModifier:
case sshCertTypeModifier:
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
case sshCertificatePrincipalsModifier:
case sshCertPrincipalsModifier:
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
case sshCertificateValidAfterModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())

Loading…
Cancel
Save