diff --git a/Gopkg.lock b/Gopkg.lock index 03026d94..1f91a67b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -262,15 +262,20 @@ [[projects]] branch = "master" - digest = "1:5dd7da6df07f42194cb25d162b4b89664ed7b08d7d4334f6a288393d54b095ce" + digest = "1:afc49fe39c8c591fc2c8ddc73adc4c69e67125dde6c58e24c91b3b0cf78602be" name = "golang.org/x/crypto" packages = [ "cryptobyte", "cryptobyte/asn1", + "curve25519", "ed25519", "ed25519/internal/edwards25519", + "internal/chacha20", + "internal/subtle", "ocsp", "pbkdf2", + "poly1305", + "ssh", "ssh/terminal", ] pruneopts = "UT" @@ -394,6 +399,7 @@ "github.com/urfave/cli", "golang.org/x/crypto/ed25519", "golang.org/x/crypto/ocsp", + "golang.org/x/crypto/ssh", "golang.org/x/net/http2", "gopkg.in/square/go-jose.v2", "gopkg.in/square/go-jose.v2/jwt", diff --git a/api/api.go b/api/api.go index f1013c0a..fd091c86 100644 --- a/api/api.go +++ b/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/dsa" "crypto/ecdsa" "crypto/rsa" @@ -26,9 +27,10 @@ import ( // Authority is the interface implemented by a CA authority. type Authority interface { + SSHAuthority // NOTE: Authorize will be deprecated in future releases. Please use the - // context specific Authoirize[Sign|Revoke|etc.] methods. - Authorize(ott string) ([]provisioner.SignOption, error) + // context specific Authorize[Sign|Revoke|etc.] methods. + Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) GetTLSOptions() *tlsutil.TLSOptions Root(shasum string) (*x509.Certificate, error) @@ -249,6 +251,8 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("GET", "/federation", h.Federation) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) + // SSH CA + r.MethodFunc("POST", "/sign-ssh", h.SignSSH) } // Health is an HTTP handler that returns the status of the server. diff --git a/api/api_test.go b/api/api_test.go index 88110314..5ece5cc9 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -31,6 +31,7 @@ import ( "github.com/smallstep/certificates/logging" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/jose" + "golang.org/x/crypto/ssh" ) const ( @@ -424,7 +425,7 @@ type mockProvisioner struct { getEncryptedKey func() (string, string, bool) init func(provisioner.Config) error authorizeRevoke func(ott string) error - authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewal func(*x509.Certificate) error } @@ -480,9 +481,9 @@ func (m *mockProvisioner) AuthorizeRevoke(ott string) error { return m.err } -func (m *mockProvisioner) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { +func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.authorizeSign != nil { - return m.authorizeSign(ott) + return m.authorizeSign(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } @@ -501,6 +502,8 @@ type mockAuthority struct { getTLSOptions func() *tlsutil.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) + signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) @@ -511,7 +514,7 @@ type mockAuthority struct { } // TODO: remove once Authorize is deprecated. -func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) { +func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return m.AuthorizeSign(ott) } @@ -543,6 +546,20 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err } +func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) { if m.renew != nil { return m.renew(cert) diff --git a/api/ssh.go b/api/ssh.go new file mode 100644 index 00000000..7bcae7cf --- /dev/null +++ b/api/ssh.go @@ -0,0 +1,159 @@ +package api + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "golang.org/x/crypto/ssh" +) + +// SSHAuthority is the interface implemented by a SSH CA authority. +type SSHAuthority interface { + SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) +} + +// SignSSHRequest is the request body of an SSH certificate request. +type SignSSHRequest struct { + PublicKey []byte `json:"publicKey"` //base64 encoded + OTT string `json:"ott"` + CertType string `json:"certType,omitempty"` + Principals []string `json:"principals,omitempty"` + ValidAfter TimeDuration `json:"validAfter,omitempty"` + ValidBefore TimeDuration `json:"validBefore,omitempty"` + AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"` +} + +// SignSSHResponse is the response object that returns the SSH certificate. +type SignSSHResponse struct { + Certificate SSHCertificate `json:"crt"` + AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"` +} + +// SSHCertificate represents the response SSH certificate. +type SSHCertificate struct { + *ssh.Certificate `json:"omitempty"` +} + +// MarshalJSON implements the json.Marshaler interface. Returns a quoted, +// base64 encoded, openssh wire format version of the certificate. +func (c SSHCertificate) MarshalJSON() ([]byte, error) { + if c.Certificate == nil { + return []byte("null"), nil + } + s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal()) + return []byte(`"` + s + `"`), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is +// expected to be a quoted, base64 encoded, openssh wire formatted block of bytes. +func (c *SSHCertificate) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrap(err, "error decoding certificate") + } + if s == "" { + c.Certificate = nil + return nil + } + certData, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return errors.Wrap(err, "error decoding ssh certificate") + } + pub, err := ssh.ParsePublicKey(certData) + if err != nil { + return errors.Wrap(err, "error parsing ssh certificate") + } + cert, ok := pub.(*ssh.Certificate) + if !ok { + return errors.Errorf("error decoding ssh certificate: %T is not an *ssh.Certificate", pub) + } + c.Certificate = cert + return nil +} + +// Validate validates the SignSSHRequest. +func (s *SignSSHRequest) Validate() error { + switch { + case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: + return errors.Errorf("unknown certType %s", s.CertType) + case len(s.PublicKey) == 0: + return errors.New("missing or empty publicKey") + case len(s.OTT) == 0: + return errors.New("missing or empty ott") + default: + return nil + } +} + +// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token +// (ott) from the body and creates a new SSH certificate with the information in +// the request. +func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { + var body SignSSHRequest + if err := ReadJSON(r.Body, &body); err != nil { + WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + return + } + + logOtt(w, body.OTT) + if err := body.Validate(); err != nil { + WriteError(w, BadRequest(err)) + return + } + + publicKey, err := ssh.ParsePublicKey(body.PublicKey) + if err != nil { + WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey"))) + return + } + + var addUserPublicKey ssh.PublicKey + if body.AddUserPublicKey != nil { + addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) + if err != nil { + WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey"))) + return + } + } + + opts := provisioner.SSHOptions{ + CertType: body.CertType, + Principals: body.Principals, + ValidBefore: body.ValidBefore, + ValidAfter: body.ValidAfter, + } + + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod) + signOpts, err := h.Authority.Authorize(ctx, body.OTT) + if err != nil { + WriteError(w, Unauthorized(err)) + return + } + + cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + var addUserCertificate *SSHCertificate + if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { + addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) + if err != nil { + WriteError(w, Forbidden(err)) + return + } + addUserCertificate = &SSHCertificate{addUserCert} + } + + w.WriteHeader(http.StatusCreated) + JSON(w, &SignSSHResponse{ + Certificate: SSHCertificate{cert}, + AddUserCertificate: addUserCertificate, + }) +} diff --git a/api/ssh_test.go b/api/ssh_test.go new file mode 100644 index 00000000..f37bcad8 --- /dev/null +++ b/api/ssh_test.go @@ -0,0 +1,327 @@ +package api + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/logging" + "golang.org/x/crypto/ssh" +) + +var ( + sshSignerKey = mustKey() + sshUserKey = mustKey() + sshHostKey = mustKey() +) + +func mustKey() *ecdsa.PrivateKey { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + return priv +} + +func signSSHCertificate(cert *ssh.Certificate) error { + signerKey, err := ssh.NewPublicKey(sshSignerKey.Public()) + if err != nil { + return err + } + signer, err := ssh.NewSignerFromSigner(sshSignerKey) + if err != nil { + return err + } + cert.SignatureKey = signerKey + data := cert.Marshal() + data = data[:len(data)-4] + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return err + } + cert.Signature = sig + return nil +} + +func getSignedUserCertificate() (*ssh.Certificate, error) { + key, err := ssh.NewPublicKey(sshUserKey.Public()) + if err != nil { + return nil, err + } + t := time.Now() + cert := &ssh.Certificate{ + Nonce: []byte("1234567890"), + Key: key, + Serial: 1234567890, + CertType: ssh.UserCert, + KeyId: "user@localhost", + ValidPrincipals: []string{"user"}, + ValidAfter: uint64(t.Unix()), + ValidBefore: uint64(t.Add(time.Hour).Unix()), + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + }, + }, + Reserved: []byte{}, + } + if err := signSSHCertificate(cert); err != nil { + return nil, err + } + return cert, nil +} + +func getSignedHostCertificate() (*ssh.Certificate, error) { + key, err := ssh.NewPublicKey(sshHostKey.Public()) + if err != nil { + return nil, err + } + t := time.Now() + cert := &ssh.Certificate{ + Nonce: []byte("1234567890"), + Key: key, + Serial: 1234567890, + CertType: ssh.UserCert, + KeyId: "internal.smallstep.com", + ValidPrincipals: []string{"internal.smallstep.com"}, + ValidAfter: uint64(t.Unix()), + ValidBefore: uint64(t.Add(time.Hour).Unix()), + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + Reserved: []byte{}, + } + if err := signSSHCertificate(cert); err != nil { + return nil, err + } + return cert, nil +} + +func TestSSHCertificate_MarshalJSON(t *testing.T) { + user, err := getSignedUserCertificate() + assert.FatalError(t, err) + host, err := getSignedHostCertificate() + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + type fields struct { + Certificate *ssh.Certificate + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + {"nil", fields{Certificate: nil}, []byte("null"), false}, + {"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false}, + {"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := SSHCertificate{ + Certificate: tt.fields.Certificate, + } + got, err := c.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSSHCertificate_UnmarshalJSON(t *testing.T) { + user, err := getSignedUserCertificate() + assert.FatalError(t, err) + host, err := getSignedHostCertificate() + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal()) + + type args struct { + data []byte + } + tests := []struct { + name string + args args + want *ssh.Certificate + wantErr bool + }{ + {"null", args{[]byte(`null`)}, nil, false}, + {"empty", args{[]byte(`""`)}, nil, false}, + {"user", args{[]byte(`"` + userB64 + `"`)}, user, false}, + {"host", args{[]byte(`"` + hostB64 + `"`)}, host, false}, + {"bad-string", args{[]byte(userB64)}, nil, true}, + {"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true}, + {"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true}, + {"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &SSHCertificate{} + if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(tt.want, c.Certificate) { + t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want) + } + }) + } +} + +func TestSignSSHRequest_Validate(t *testing.T) { + type fields struct { + PublicKey []byte + OTT string + CertType string + Principals []string + ValidAfter TimeDuration + ValidBefore TimeDuration + AddUserPublicKey []byte + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, + {"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, + {"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, + {"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SignSSHRequest{ + PublicKey: tt.fields.PublicKey, + OTT: tt.fields.OTT, + CertType: tt.fields.CertType, + Principals: tt.fields.Principals, + ValidAfter: tt.fields.ValidAfter, + ValidBefore: tt.fields.ValidBefore, + AddUserPublicKey: tt.fields.AddUserPublicKey, + } + if err := s.Validate(); (err != nil) != tt.wantErr { + t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_caHandler_SignSSH(t *testing.T) { + user, err := getSignedUserCertificate() + assert.FatalError(t, err) + host, err := getSignedHostCertificate() + assert.FatalError(t, err) + + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + userReq, err := json.Marshal(SignSSHRequest{ + PublicKey: user.Key.Marshal(), + OTT: "ott", + }) + assert.FatalError(t, err) + hostReq, err := json.Marshal(SignSSHRequest{ + PublicKey: host.Key.Marshal(), + OTT: "ott", + }) + assert.FatalError(t, err) + userAddReq, err := json.Marshal(SignSSHRequest{ + PublicKey: user.Key.Marshal(), + OTT: "ott", + AddUserPublicKey: user.Key.Marshal(), + }) + assert.FatalError(t, err) + + type fields struct { + Authority Authority + } + type args struct { + w http.ResponseWriter + r *http.Request + } + tests := []struct { + name string + req []byte + authErr error + signCert *ssh.Certificate + signErr error + addUserCert *ssh.Certificate + addUserErr error + body []byte + statusCode int + }{ + {"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, + {"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, + {"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, + {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized}, + {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, + {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + return []provisioner.SignOption{}, tt.authErr + }, + signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + return tt.signCert, tt.signErr + }, + signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + return tt.addUserCert, tt.addUserErr + }, + }).(*caHandler) + + req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req)) + w := httptest.NewRecorder() + h.SignSSH(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Root unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body) + } + } + }) + } +} diff --git a/authority/authority.go b/authority/authority.go index 33340029..848a4f63 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -1,12 +1,14 @@ package authority import ( + "crypto" "crypto/sha256" "crypto/x509" "encoding/hex" "sync" "time" + "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/cli/crypto/pemutil" @@ -20,6 +22,8 @@ type Authority struct { config *Config rootX509Certs []*x509.Certificate intermediateIdentity *x509util.Identity + sshCAUserCertSignKey crypto.Signer + sshCAHostCertSignKey crypto.Signer validateOnce bool certificates *sync.Map startTime time.Time @@ -117,6 +121,22 @@ func (a *Authority) init() error { } } + // Decrypt and load SSH keys + if a.config.SSH != nil { + if a.config.SSH.HostKey != "" { + a.sshCAHostCertSignKey, err = parseCryptoSigner(a.config.SSH.HostKey, a.config.Password) + if err != nil { + return err + } + } + if a.config.SSH.UserKey != "" { + a.sshCAUserCertSignKey, err = parseCryptoSigner(a.config.SSH.UserKey, a.config.Password) + if err != nil { + return err + } + } + } + // Store all the provisioners for _, p := range a.config.AuthorityConfig.Provisioners { if err := a.provisioners.Store(p); err != nil { @@ -143,3 +163,19 @@ func (a *Authority) GetDatabase() db.AuthDB { func (a *Authority) Shutdown() error { return a.db.Shutdown() } + +func parseCryptoSigner(filename, password string) (crypto.Signer, error) { + var opts []pemutil.Options + if password != "" { + opts = append(opts, pemutil.WithPassword([]byte(password))) + } + key, err := pemutil.Read(filename, opts...) + if err != nil { + return nil, err + } + signer, ok := key.(crypto.Signer) + if !ok { + return nil, errors.Errorf("key %s of type %T cannot be used for signing operations", filename, key) + } + return signer, nil +} diff --git a/authority/authorize.go b/authority/authorize.go index 1a7e45d3..b8f7cb6f 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto/x509" "net/http" "strings" @@ -72,33 +73,51 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) { return p, nil } -// Authorize is a passthrough to AuthorizeSign. -// NOTE: Authorize will be deprecated in a future release. Please use the -// context specific Authorize[Sign|Revoke|etc.] going forwards. -func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) { - return a.AuthorizeSign(ott) +// Authorize grabs the method from the context and authorizes a signature +// request by validating the one-time-token. +func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + var errContext = apiCtx{"ott": ott} + switch m := provisioner.MethodFromContext(ctx); m { + case provisioner.SignMethod: + return a.authorizeSign(ctx, ott) + case provisioner.SignSSHMethod: + if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { + return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} + } + return a.authorizeSign(ctx, ott) + case provisioner.RevokeMethod: + return nil, &apiError{errors.New("authorize: revoke method is not supported"), http.StatusInternalServerError, errContext} + default: + return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext} + } } -// AuthorizeSign authorizes a signature request by validating and authenticating -// a OTT that must be sent w/ the request. -func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - var errContext = context{"ott": ott} - +// authorizeSign loads the provisioner from the token, checks that it has not +// been used again and calls the provisioner AuthorizeSign method. Returns a +// list of methods to apply to the signing flow. +func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + var errContext = apiCtx{"ott": ott} p, err := a.authorizeToken(ott) if err != nil { return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext} } - - // Call the provisioner AuthorizeSign method to apply provisioner specific - // auth claims and get the signing options. - opts, err := p.AuthorizeSign(ott) + opts, err := p.AuthorizeSign(ctx, ott) if err != nil { return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext} } - return opts, nil } +// AuthorizeSign authorizes a signature request by validating and authenticating +// a OTT that must be sent w/ the request. +// +// NOTE: This method is deprecated and should not be used. We make it available +// in the short term os as not to break existing clients. +func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + return a.Authorize(ctx, ott) +} + // authorizeRevoke authorizes a revocation request by validating and authenticating // the RevokeOptions POSTed with the request. // Returns a tuple of the provisioner ID and error, if one occurred. diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 27a15513..23a2983c 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -1,11 +1,14 @@ package authority import ( + "context" "crypto/x509" "net/http" "testing" "time" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/cli/crypto/pemutil" @@ -72,7 +75,7 @@ func TestAuthority_authorizeToken(t *testing.T) { auth: a, ott: "foo", err: &apiError{errors.New("authorizeToken: error parsing token"), - http.StatusUnauthorized, context{"ott": "foo"}}, + http.StatusUnauthorized, apiCtx{"ott": "foo"}}, } }, "fail/prehistoric-token": func(t *testing.T) *authorizeTest { @@ -91,7 +94,7 @@ func TestAuthority_authorizeToken(t *testing.T) { auth: a, ott: raw, err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"), - http.StatusUnauthorized, context{"ott": raw}}, + http.StatusUnauthorized, apiCtx{"ott": raw}}, } }, "fail/provisioner-not-found": func(t *testing.T) *authorizeTest { @@ -113,7 +116,7 @@ func TestAuthority_authorizeToken(t *testing.T) { auth: a, ott: raw, err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"), - http.StatusUnauthorized, context{"ott": raw}}, + http.StatusUnauthorized, apiCtx{"ott": raw}}, } }, "ok/simpledb": func(t *testing.T) *authorizeTest { @@ -150,7 +153,7 @@ func TestAuthority_authorizeToken(t *testing.T) { auth: _a, ott: raw, err: &apiError{errors.New("authorizeToken: token already used"), - http.StatusUnauthorized, context{"ott": raw}}, + http.StatusUnauthorized, apiCtx{"ott": raw}}, } }, "ok/mockNoSQLDB": func(t *testing.T) *authorizeTest { @@ -198,7 +201,7 @@ func TestAuthority_authorizeToken(t *testing.T) { auth: _a, ott: raw, err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"), - http.StatusInternalServerError, context{"ott": raw}}, + http.StatusInternalServerError, apiCtx{"ott": raw}}, } }, "fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest { @@ -223,7 +226,7 @@ func TestAuthority_authorizeToken(t *testing.T) { auth: _a, ott: raw, err: &apiError{errors.New("authorizeToken: token already used"), - http.StatusUnauthorized, context{"ott": raw}}, + http.StatusUnauthorized, apiCtx{"ott": raw}}, } }, } @@ -388,7 +391,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) { auth: a, ott: "foo", err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"), - http.StatusUnauthorized, context{"ott": "foo"}}, + http.StatusUnauthorized, apiCtx{"ott": "foo"}}, } }, "fail/invalid-subject": func(t *testing.T) *authorizeTest { @@ -406,7 +409,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) { auth: a, ott: raw, err: &apiError{errors.New("authorizeSign: token subject cannot be empty"), - http.StatusUnauthorized, context{"ott": raw}}, + http.StatusUnauthorized, apiCtx{"ott": raw}}, } }, "ok": func(t *testing.T) *authorizeTest { @@ -480,7 +483,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, ott: "foo", err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"), - http.StatusUnauthorized, context{"ott": "foo"}}, + http.StatusUnauthorized, apiCtx{"ott": "foo"}}, } }, "fail/invalid-subject": func(t *testing.T) *authorizeTest { @@ -498,7 +501,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, ott: raw, err: &apiError{errors.New("authorizeSign: token subject cannot be empty"), - http.StatusUnauthorized, context{"ott": raw}}, + http.StatusUnauthorized, apiCtx{"ott": raw}}, } }, "ok": func(t *testing.T) *authorizeTest { @@ -522,8 +525,8 @@ func TestAuthority_Authorize(t *testing.T) { for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - - got, err := tc.auth.Authorize(tc.ott) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + got, err := tc.auth.Authorize(ctx, tc.ott) if err != nil { if assert.NotNil(t, tc.err) { assert.Nil(t, got) @@ -573,7 +576,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) { auth: a, crt: fooCrt, err: &apiError{errors.New("renew: force"), - http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}}, + http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}}, } }, "fail/revoked": func(t *testing.T) *authorizeTest { @@ -587,7 +590,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) { auth: a, crt: fooCrt, err: &apiError{errors.New("renew: certificate has been revoked"), - http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}}, + http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}}, } }, "fail/load-provisioner": func(t *testing.T) *authorizeTest { @@ -601,7 +604,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) { auth: a, crt: otherCrt, err: &apiError{errors.New("renew: provisioner not found"), - http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}}, + http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}}, } }, "fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest { @@ -616,7 +619,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) { auth: a, crt: renewDisabledCrt, err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), - http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}}, + http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}}, } }, "ok": func(t *testing.T) *authorizeTest { diff --git a/authority/config.go b/authority/config.go index 77854812..99fdf457 100644 --- a/authority/config.go +++ b/authority/config.go @@ -28,11 +28,19 @@ var ( Renegotiation: false, } defaultDisableRenewal = false + defaultEnableSSHCA = false globalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 4 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, } ) @@ -44,6 +52,7 @@ type Config struct { IntermediateKey string `json:"key"` Address string `json:"address"` DNSNames []string `json:"dnsNames"` + SSH *SSHConfig `json:"ssh,omitempty"` Logger json.RawMessage `json:"logger,omitempty"` DB *db.Config `json:"db,omitempty"` Monitoring json.RawMessage `json:"monitoring,omitempty"` @@ -92,6 +101,14 @@ func (c *AuthConfig) Validate(audiences provisioner.Audiences) error { return nil } +// SSHConfig contains the user and host keys. +type SSHConfig struct { + HostKey string `json:"hostKey"` + UserKey string `json:"userKey"` + AddUserPrincipal string `json:"addUserPrincipal"` + AddUserCommand string `json:"addUserCommand"` +} + // LoadConfiguration parses the given filename in JSON format and returns the // configuration struct. func LoadConfiguration(filename string) (*Config, error) { diff --git a/authority/error.go b/authority/error.go index 056d3147..85293f20 100644 --- a/authority/error.go +++ b/authority/error.go @@ -4,13 +4,13 @@ import ( "net/http" ) -type context map[string]interface{} +type apiCtx map[string]interface{} // Error implements the api.Error interface and adds context to error messages. type apiError struct { err error code int - context context + context apiCtx } // Cause implements the errors.Causer interface and returns the original error. diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 421b4af8..dae817f8 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/sha256" "crypto/x509" "encoding/base64" @@ -266,13 +267,21 @@ func (p *AWS) Init(config Config) (err error) { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) { +func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { payload, err := p.authorizeToken(token) if err != nil { return nil, err } - doc := payload.document + // Check for the sign ssh method, default to sign X.509 + if m := MethodFromContext(ctx); m == SignSSHMethod { + if p.claimer.IsSSHCAEnabled() == false { + return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + } + return p.authorizeSSHSign(payload) + } + + doc := payload.document // Enforce known CN and default DNS and IP if configured. // By default we'll accept the CN and SANs in the CSR. // There's no way to trust them other than TOFU. @@ -433,3 +442,35 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { payload.document = doc return &payload, nil } + +// authorizeSSHSign returns the list of SignOption for a SignSSH request. +func (p *AWS) authorizeSSHSign(claims *awsPayload) ([]SignOption, error) { + doc := claims.document + + signOptions := []SignOption{ + // set the key id to the token subject + sshCertificateKeyIDModifier(claims.Subject), + } + + // Default to host + known IPs/hostnames + defaults := SSHOptions{ + CertType: SSHHostCert, + Principals: []string{ + doc.PrivateIP, + fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region), + }, + } + // Validate user options + signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + // Set defaults if not given as user options + signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + + return append(signOptions, + // set the default extensions + &sshDefaultExtensionModifier{}, + // checks the validity bounds, and set the validity if has not been set + &sshCertificateValidityModifier{p.claimer}, + // require all the fields in the SSH certificate + &sshCertificateDefaultValidator{}, + ), nil +} diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 98c885fa..4fe9367a 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -1,6 +1,8 @@ package provisioner import ( + "context" + "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" @@ -347,7 +349,8 @@ func TestAWS_AuthorizeSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.aws.AuthorizeSign(tt.args.token) + ctx := NewContextWithMethod(context.Background(), SignMethod) + got, err := tt.aws.AuthorizeSign(ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return @@ -357,6 +360,84 @@ func TestAWS_AuthorizeSign(t *testing.T) { } } +func TestAWS_AuthorizeSign_SSH(t *testing.T) { + tm, fn := mockNow() + defer fn() + + p1, srv, err := generateAWSWithServer() + assert.FatalError(t, err) + defer srv.Close() + + t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com") + assert.FatalError(t, err) + + key, err := generateJSONWebKey() + assert.FatalError(t, err) + + signer, err := generateJSONWebKey() + assert.FatalError(t, err) + + hostDuration := p1.claimer.DefaultHostSSHCertDuration() + expectedHostOptions := &SSHOptions{ + CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + expectedHostOptionsIP := &SSHOptions{ + CertType: "host", Principals: []string{"127.0.0.1"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + expectedHostOptionsHostname := &SSHOptions{ + CertType: "host", Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + + type args struct { + token string + sshOpts SSHOptions + } + tests := []struct { + name string + aws *AWS + args args + expected *SSHOptions + wantErr bool + wantSignErr bool + }{ + {"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false}, + {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false}, + {"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}}, expectedHostOptionsIP, false, false}, + {"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptionsHostname, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false}, + {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true}, + {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true}, + {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}}, nil, false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), SignSSHMethod) + got, err := tt.aws.AuthorizeSign(ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else if assert.NotNil(t, got) { + cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) + if (err != nil) != tt.wantSignErr { + t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) + } else { + if tt.wantSignErr { + assert.Nil(t, cert) + } else { + assert.NoError(t, validateSSHCertificate(cert, tt.expected)) + } + } + } + }) + } +} func TestAWS_AuthorizeRenewal(t *testing.T) { p1, err := generateAWS() assert.FatalError(t, err) diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 40d14c9c..7619d202 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/sha256" "crypto/x509" "encoding/hex" @@ -209,7 +210,7 @@ func (p *Azure) Init(config Config) (err error) { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) { +func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, errors.Wrapf(err, "error parsing token") @@ -264,6 +265,14 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) { } } + // Check for the sign ssh method, default to sign X.509 + if m := MethodFromContext(ctx); m == SignSSHMethod { + if p.claimer.IsSSHCAEnabled() == false { + return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + } + return p.authorizeSSHSign(claims, name) + } + // Enforce known common name and default DNS if configured. // By default we'll accept the CN and SANs in the CSR. // There's no way to trust them other than TOFU. @@ -296,6 +305,33 @@ func (p *Azure) AuthorizeRevoke(token string) error { return errors.New("revoke is not supported on a Azure provisioner") } +// authorizeSSHSign returns the list of SignOption for a SignSSH request. +func (p *Azure) authorizeSSHSign(claims azurePayload, name string) ([]SignOption, error) { + signOptions := []SignOption{ + // set the key id to the token subject + sshCertificateKeyIDModifier(name), + } + + // Default to host + known hostnames + defaults := SSHOptions{ + CertType: SSHHostCert, + Principals: []string{name}, + } + // Validate user options + signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + // Set defaults if not given as user options + signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + + return append(signOptions, + // set the default extensions + &sshDefaultExtensionModifier{}, + // checks the validity bounds, and set the validity if has not been set + &sshCertificateValidityModifier{p.claimer}, + // require all the fields in the SSH certificate + &sshCertificateDefaultValidator{}, + ), nil +} + // assertConfig initializes the config if it has not been initialized func (p *Azure) assertConfig() { if p.config == nil { diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 8fbc9b20..21ebb59d 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -1,6 +1,8 @@ package provisioner import ( + "context" + "crypto" "crypto/sha256" "crypto/x509" "encoding/hex" @@ -295,7 +297,8 @@ func TestAzure_AuthorizeSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.azure.AuthorizeSign(tt.args.token) + ctx := NewContextWithMethod(context.Background(), SignMethod) + got, err := tt.azure.AuthorizeSign(ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return @@ -305,6 +308,75 @@ func TestAzure_AuthorizeSign(t *testing.T) { } } +func TestAzure_AuthorizeSign_SSH(t *testing.T) { + tm, fn := mockNow() + defer fn() + + p1, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + + t1, err := p1.GetIdentityToken("subject", "caURL") + assert.FatalError(t, err) + + key, err := generateJSONWebKey() + assert.FatalError(t, err) + + signer, err := generateJSONWebKey() + assert.FatalError(t, err) + + hostDuration := p1.claimer.DefaultHostSSHCertDuration() + expectedHostOptions := &SSHOptions{ + CertType: "host", Principals: []string{"virtualMachine"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + + type args struct { + token string + sshOpts SSHOptions + } + tests := []struct { + name string + azure *Azure + args args + expected *SSHOptions + wantErr bool + wantSignErr bool + }{ + {"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false}, + {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false}, + {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true}, + {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true}, + {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}}, nil, false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), SignSSHMethod) + got, err := tt.azure.AuthorizeSign(ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else if assert.NotNil(t, got) { + cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) + if (err != nil) != tt.wantSignErr { + t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) + } else { + if tt.wantSignErr { + assert.Nil(t, cert) + } else { + assert.NoError(t, validateSSHCertificate(cert, tt.expected)) + } + } + } + }) + } +} + func TestAzure_AuthorizeRenewal(t *testing.T) { p1, err := generateAzure() assert.FatalError(t, err) diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 1109e0c7..4eba5ad7 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -8,10 +8,19 @@ import ( // Claims so that individual provisioners can override global claims. type Claims struct { + // TLS CA properties MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"` + // SSH CA properties + MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` + MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` + DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"` + MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"` + MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` + DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` + EnableSSHCA *bool `json:"enableSSHCA,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -30,11 +39,19 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() + enableSSHCA := c.IsSSHCAEnabled() return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - DisableRenewal: &disableRenewal, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + DisableRenewal: &disableRenewal, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, + EnableSSHCA: &enableSSHCA, } } @@ -78,6 +95,76 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } +// DefaultUserSSHCertDuration returns the default SSH user cert duration for the +// provisioner. If the default is not set within the provisioner, then the +// global default from the authority configuration will be used. +func (c *Claimer) DefaultUserSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.DefaultUserSSHDur == nil { + return c.global.DefaultUserSSHDur.Duration + } + return c.claims.DefaultUserSSHDur.Duration +} + +// MinUserSSHCertDuration returns the minimum SSH user cert duration for the +// provisioner. If the minimum is not set within the provisioner, then the +// global minimum from the authority configuration will be used. +func (c *Claimer) MinUserSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MinUserSSHDur == nil { + return c.global.MinUserSSHDur.Duration + } + return c.claims.MinUserSSHDur.Duration +} + +// MaxUserSSHCertDuration returns the maximum SSH user cert duration for the +// provisioner. If the maximum is not set within the provisioner, then the +// global maximum from the authority configuration will be used. +func (c *Claimer) MaxUserSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MaxUserSSHDur == nil { + return c.global.MaxUserSSHDur.Duration + } + return c.claims.MaxUserSSHDur.Duration +} + +// DefaultHostSSHCertDuration returns the default SSH host cert duration for the +// provisioner. If the default is not set within the provisioner, then the +// global default from the authority configuration will be used. +func (c *Claimer) DefaultHostSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.DefaultHostSSHDur == nil { + return c.global.DefaultHostSSHDur.Duration + } + return c.claims.DefaultHostSSHDur.Duration +} + +// MinHostSSHCertDuration returns the minimum SSH host cert duration for the +// provisioner. If the minimum is not set within the provisioner, then the +// global minimum from the authority configuration will be used. +func (c *Claimer) MinHostSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MinHostSSHDur == nil { + return c.global.MinHostSSHDur.Duration + } + return c.claims.MinHostSSHDur.Duration +} + +// MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the +// provisioner. If the maximum is not set within the provisioner, then the +// global maximum from the authority configuration will be used. +func (c *Claimer) MaxHostSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MaxHostSSHDur == nil { + return c.global.MaxHostSSHDur.Duration + } + return c.claims.MaxHostSSHDur.Duration +} + +// IsSSHCAEnabled returns if the SSH CA is enabled for the provisioner. If the +// property is not set within the provisioner, then the global value from the +// authority configuration will be used. +func (c *Claimer) IsSSHCAEnabled() bool { + if c.claims == nil || c.claims.EnableSSHCA == nil { + return *c.global.EnableSSHCA + } + return *c.claims.EnableSSHCA +} + // Validate validates and modifies the Claims with default values. func (c *Claimer) Validate() error { var ( diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 5ee92237..bb8c4ede 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -2,6 +2,7 @@ package provisioner import ( "bytes" + "context" "crypto/sha256" "crypto/x509" "encoding/hex" @@ -205,13 +206,21 @@ func (p *GCP) Init(config Config) error { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) { +func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token) if err != nil { return nil, err } - ce := claims.Google.ComputeEngine + // Check for the sign ssh method, default to sign X.509 + if m := MethodFromContext(ctx); m == SignSSHMethod { + if p.claimer.IsSSHCAEnabled() == false { + return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + } + return p.authorizeSSHSign(claims) + } + + ce := claims.Google.ComputeEngine // Enforce known common name and default DNS if configured. // By default we we'll accept the CN and SANs in the CSR. // There's no way to trust them other than TOFU. @@ -345,3 +354,35 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { return &claims, nil } + +// authorizeSSHSign returns the list of SignOption for a SignSSH request. +func (p *GCP) authorizeSSHSign(claims *gcpPayload) ([]SignOption, error) { + ce := claims.Google.ComputeEngine + + signOptions := []SignOption{ + // set the key id to the token subject + sshCertificateKeyIDModifier(ce.InstanceName), + } + + // Default to host + known hostnames + defaults := SSHOptions{ + CertType: SSHHostCert, + Principals: []string{ + fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID), + fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID), + }, + } + // Validate user options + signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + // Set defaults if not given as user options + signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + + return append(signOptions, + // set the default extensions + &sshDefaultExtensionModifier{}, + // checks the validity bounds, and set the validity if has not been set + &sshCertificateValidityModifier{p.claimer}, + // require all the fields in the SSH certificate + &sshCertificateDefaultValidator{}, + ), nil +} diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 04dffd95..077537a1 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -1,6 +1,8 @@ package provisioner import ( + "context" + "crypto" "crypto/sha256" "crypto/x509" "encoding/hex" @@ -330,7 +332,8 @@ func TestGCP_AuthorizeSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.gcp.AuthorizeSign(tt.args.token) + ctx := NewContextWithMethod(context.Background(), SignMethod) + got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return @@ -340,6 +343,87 @@ func TestGCP_AuthorizeSign(t *testing.T) { } } +func TestGCP_AuthorizeSign_SSH(t *testing.T) { + tm, fn := mockNow() + defer fn() + + p1, err := generateGCP() + assert.FatalError(t, err) + + t1, err := generateGCPToken(p1.ServiceAccounts[0], + "https://accounts.google.com", p1.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now(), &p1.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + + key, err := generateJSONWebKey() + assert.FatalError(t, err) + + signer, err := generateJSONWebKey() + assert.FatalError(t, err) + + hostDuration := p1.claimer.DefaultHostSSHCertDuration() + expectedHostOptions := &SSHOptions{ + CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + expectedHostOptionsPrincipal1 := &SSHOptions{ + CertType: "host", Principals: []string{"instance-name.c.project-id.internal"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + expectedHostOptionsPrincipal2 := &SSHOptions{ + CertType: "host", Principals: []string{"instance-name.zone.c.project-id.internal"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + + type args struct { + token string + sshOpts SSHOptions + } + tests := []struct { + name string + gcp *GCP + args args + expected *SSHOptions + wantErr bool + wantSignErr bool + }{ + {"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false}, + {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false}, + {"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}}, expectedHostOptionsPrincipal1, false, false}, + {"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}}, expectedHostOptionsPrincipal2, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false}, + {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true}, + {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true}, + {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}}, nil, false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), SignSSHMethod) + got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else if assert.NotNil(t, got) { + cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) + if (err != nil) != tt.wantSignErr { + t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) + } else { + if tt.wantSignErr { + assert.Nil(t, cert) + } else { + assert.NoError(t, validateSSHCertificate(cert, tt.expected)) + } + } + } + }) + } +} + func TestGCP_AuthorizeRenewal(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index a4bc1137..9bf38f17 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/x509" "time" @@ -12,7 +13,12 @@ import ( // jwtPayload extends jwt.Claims with step attributes. type jwtPayload struct { jose.Claims - SANs []string `json:"sans,omitempty"` + SANs []string `json:"sans,omitempty"` + Step *stepPayload `json:"step,omitempty"` +} + +type stepPayload struct { + SSH *SSHOptions `json:"ssh,omitempty"` } // JWK is the default provisioner, an entity that can sign tokens necessary for @@ -129,11 +135,20 @@ func (p *JWK) AuthorizeRevoke(token string) error { } // AuthorizeSign validates the given token. -func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) { +func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.audiences.Sign) if err != nil { return nil, err } + + // Check for SSH token + if claims.Step != nil && claims.Step.SSH != nil { + if p.claimer.IsSSHCAEnabled() == false { + return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + } + return p.authorizeSSHSign(claims) + } + // NOTE: This is for backwards compatibility with older versions of cli // and certificates. Older versions added the token subject as the only SAN // in a CSR by default. @@ -161,3 +176,41 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error { } return nil } + +// authorizeSSHSign returns the list of SignOption for a SignSSH request. +func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) { + t := now() + opts := claims.Step.SSH + signOptions := []SignOption{ + // validates user's SSHOptions with the ones in the token + sshCertificateOptionsValidator(*opts), + // set the key id to the token subject + sshCertificateKeyIDModifier(claims.Subject), + } + + // Add modifiers from custom claims + if opts.CertType != "" { + signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType)) + } + if len(opts.Principals) > 0 { + signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals)) + } + if !opts.ValidAfter.IsZero() { + signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) + } + if !opts.ValidBefore.IsZero() { + signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) + } + + // Default to a user certificate with no principals if not set + signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) + + return append(signOptions, + // set the default extensions + &sshDefaultExtensionModifier{}, + // checks the validity bounds, and set the validity if has not been set + &sshCertificateValidityModifier{p.claimer}, + // require all the fields in the SSH certificate + &sshCertificateDefaultValidator{}, + ), nil +} diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 5e3ad5f7..a13db307 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -1,6 +1,8 @@ package provisioner import ( + "context" + "crypto" "crypto/x509" "errors" "strings" @@ -13,11 +15,19 @@ import ( var ( defaultDisableRenewal = false + defaultEnableSSHCA = true globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, } ) @@ -259,7 +269,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got, err := tt.prov.AuthorizeSign(tt.args.token); err != nil { + ctx := NewContextWithMethod(context.Background(), SignMethod) + if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -318,3 +329,201 @@ func TestJWK_AuthorizeRenewal(t *testing.T) { }) } } + +func TestJWK_AuthorizeSign_SSH(t *testing.T) { + tm, fn := mockNow() + defer fn() + + p1, err := generateJWK() + assert.FatalError(t, err) + jwk, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + + iss, aud := p1.Name, testAudiences.Sign[0] + + t1, err := generateSimpleSSHUserToken(iss, aud, jwk) + assert.FatalError(t, err) + + t2, err := generateSimpleSSHHostToken(iss, aud, jwk) + assert.FatalError(t, err) + + // invalid signature + failSig := t1[0 : len(t1)-2] + + key, err := generateJSONWebKey() + assert.FatalError(t, err) + + signer, err := generateJSONWebKey() + assert.FatalError(t, err) + + userDuration := p1.claimer.DefaultUserSSHCertDuration() + hostDuration := p1.claimer.DefaultHostSSHCertDuration() + expectedUserOptions := &SSHOptions{ + CertType: "user", Principals: []string{"name"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), + } + expectedHostOptions := &SSHOptions{ + CertType: "host", Principals: []string{"smallstep.com"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + + type args struct { + token string + sshOpts SSHOptions + } + tests := []struct { + name string + prov *JWK + args args + expected *SSHOptions + wantErr bool + wantSignErr bool + }{ + {"user", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false}, + {"user-type", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false}, + {"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false}, + {"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false}, + {"host", p1, args{t2, SSHOptions{}}, expectedHostOptions, false, false}, + {"host-type", p1, args{t2, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, + {"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false}, + {"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false}, + {"fail-signature", p1, args{failSig, SSHOptions{}}, nil, true, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), SignSSHMethod) + got, err := tt.prov.AuthorizeSign(ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else if assert.NotNil(t, got) { + cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) + if (err != nil) != tt.wantSignErr { + t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) + } else { + if tt.wantSignErr { + assert.Nil(t, cert) + } else { + assert.NoError(t, validateSSHCertificate(cert, tt.expected)) + } + } + } + }) + } +} + +func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { + tm, fn := mockNow() + defer fn() + + p1, err := generateJWK() + assert.FatalError(t, err) + jwk, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + + sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.Sign[0], time.Now() + + key, err := generateJSONWebKey() + assert.FatalError(t, err) + + signer, err := generateJSONWebKey() + assert.FatalError(t, err) + + userDuration := p1.claimer.DefaultUserSSHCertDuration() + hostDuration := p1.claimer.DefaultHostSSHCertDuration() + expectedUserOptions := &SSHOptions{ + CertType: "user", Principals: []string{"name"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), + } + expectedHostOptions := &SSHOptions{ + CertType: "host", Principals: []string{"smallstep.com"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + type args struct { + sub, iss, aud string + iat time.Time + tokSSHOpts *SSHOptions + userSSHOpts *SSHOptions + jwk *jose.JSONWebKey + } + tests := []struct { + name string + prov *JWK + args args + expected *SSHOptions + wantErr bool + wantSignErr bool + }{ + {"ok-user", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, expectedUserOptions, false, false}, + {"ok-host", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, &SSHOptions{}, jwk}, expectedHostOptions, false, false}, + {"ok-user-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "user", Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false}, + {"ok-host-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, jwk}, expectedHostOptions, false, false}, + {"ok-user-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user"}, &SSHOptions{Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false}, + {"ok-host-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{Principals: []string{"smallstep.com"}}, &SSHOptions{CertType: "host"}, jwk}, expectedHostOptions, false, false}, + {"ok-user-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, + }, &SSHOptions{ + ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), + }, jwk}, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), ValidBefore: NewTimeDuration(tm.Add(userDuration - time.Hour)), + }, false, false}, + {"ok-user-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, + }, &SSHOptions{ + ValidBefore: NewTimeDuration(tm.Add(time.Hour)), + }, jwk}, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), + }, false, false}, + {"ok-user-validAfter-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, + }, &SSHOptions{ + ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), + }, jwk}, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), + }, false, false}, + {"ok-user-match", p1, args{sub, iss, aud, iat, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)), + }, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)), + }, jwk}, &SSHOptions{ + CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), + }, false, false}, + {"fail-certType", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{CertType: "host"}, jwk}, nil, false, true}, + {"fail-principals", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{Principals: []string{"root"}}, jwk}, nil, false, true}, + {"fail-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm)}, &SSHOptions{ValidAfter: NewTimeDuration(tm.Add(time.Hour))}, jwk}, nil, false, true}, + {"fail-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidBefore: NewTimeDuration(tm.Add(time.Hour))}, &SSHOptions{ValidBefore: NewTimeDuration(tm.Add(10 * time.Hour))}, jwk}, nil, false, true}, + {"fail-subject", p1, args{"", iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false}, + {"fail-issuer", p1, args{sub, "invalid", aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false}, + {"fail-audience", p1, args{sub, iss, "invalid", iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false}, + {"fail-expired", p1, args{sub, iss, aud, iat.Add(-6 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false}, + {"fail-notBefore", p1, args{sub, iss, aud, iat.Add(5 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), SignSSHMethod) + token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk) + assert.FatalError(t, err) + if got, err := tt.prov.AuthorizeSign(ctx, token); (err != nil) != tt.wantErr { + t.Errorf("JWK.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) + } else if !tt.wantErr && assert.NotNil(t, got) { + var opts SSHOptions + if tt.args.userSSHOpts != nil { + opts = *tt.args.userSSHOpts + } + cert, err := signSSHCertificate(key.Public().Key, opts, got, signer.Key.(crypto.Signer)) + if (err != nil) != tt.wantSignErr { + t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) + } else { + if tt.wantSignErr { + assert.Nil(t, cert) + } else { + assert.NoError(t, validateSSHCertificate(cert, tt.expected)) + } + } + } + }) + } +} diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go new file mode 100644 index 00000000..c8f96885 --- /dev/null +++ b/authority/provisioner/method.go @@ -0,0 +1,34 @@ +package provisioner + +import ( + "context" +) + +// Method indicates the action to action that we will perform, it's used as part +// of the context in the call to authorize. It defaults to Sing. +type Method int + +// The key to save the Method in the context. +type methodKey struct{} + +const ( + // SignMethod is the method used to sign X.509 certificates. + SignMethod Method = iota + // SignSSHMethod is the method used to sign SSH certificate. + SignSSHMethod + // RevokeMethod is the method used to revoke X.509 certificates. + RevokeMethod +) + +// NewContextWithMethod creates a new context from ctx and attaches method to +// it. +func NewContextWithMethod(ctx context.Context, method Method) context.Context { + return context.WithValue(ctx, methodKey{}, method) +} + +// MethodFromContext returns the Method saved in ctx. Returns Sign if the given +// context has no Method associated with it. +func MethodFromContext(ctx context.Context) Method { + m, _ := ctx.Value(methodKey{}).(Method) + return m +} diff --git a/authority/provisioner/noop.go b/authority/provisioner/noop.go index 44fd4600..5bdc0677 100644 --- a/authority/provisioner/noop.go +++ b/authority/provisioner/noop.go @@ -1,6 +1,9 @@ package provisioner -import "crypto/x509" +import ( + "context" + "crypto/x509" +) // noop provisioners is a provisioner that accepts anything. type noop struct{} @@ -28,7 +31,7 @@ func (p *noop) Init(config Config) error { return nil } -func (p *noop) AuthorizeSign(token string) ([]SignOption, error) { +func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { return []SignOption{}, nil } diff --git a/authority/provisioner/noop_test.go b/authority/provisioner/noop_test.go index 8f25eb8c..a389b6b6 100644 --- a/authority/provisioner/noop_test.go +++ b/authority/provisioner/noop_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/x509" "testing" @@ -21,7 +22,8 @@ func Test_noop(t *testing.T) { assert.Equals(t, "", key) assert.Equals(t, false, ok) - sigOptions, err := p.AuthorizeSign("foo") + ctx := NewContextWithMethod(context.Background(), SignMethod) + sigOptions, err := p.AuthorizeSign(ctx, "foo") assert.Equals(t, []SignOption{}, sigOptions) assert.Equals(t, nil, err) } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 66e8b9b7..01b59625 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/x509" "encoding/json" "net/http" @@ -259,12 +260,29 @@ func (o *OIDC) AuthorizeRevoke(token string) error { } // AuthorizeSign validates the given token. -func (o *OIDC) AuthorizeSign(token string) ([]SignOption, error) { +func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := o.authorizeToken(token) if err != nil { return nil, err } + // Check for the sign ssh method, default to sign X.509 + if m := MethodFromContext(ctx); m == SignSSHMethod { + if o.claimer.IsSSHCAEnabled() == false { + return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID()) + } + return o.authorizeSSHSign(claims) + } + + // Admins should be able to authorize any SAN + if o.IsAdmin(claims.Email) { + return []SignOption{ + profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), + newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), + newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), + }, nil + } + so := []SignOption{ defaultPublicKeyValidator{}, profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), @@ -287,6 +305,42 @@ func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error { return nil } +// authorizeSSHSign returns the list of SignOption for a SignSSH request. +func (o *OIDC) authorizeSSHSign(claims *openIDPayload) ([]SignOption, error) { + signOptions := []SignOption{ + // set the key id to the token subject + sshCertificateKeyIDModifier(claims.Email), + } + + name := SanitizeSSHUserPrincipal(claims.Email) + if !sshUserRegex.MatchString(name) { + return nil, errors.Errorf("invalid principal '%s' from email address '%s'", name, claims.Email) + } + + // Admin users will default to user + name but they can be changed by the + // user options. Non-admins are only able to sign user certificates. + defaults := SSHOptions{ + CertType: SSHUserCert, + Principals: []string{name}, + } + + if !o.IsAdmin(claims.Email) { + signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + } + + // Default to a user with name as principal if not set + signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + + return append(signOptions, + // set the default extensions + &sshDefaultExtensionModifier{}, + // checks the validity bounds, and set the validity if has not been set + &sshCertificateValidityModifier{o.claimer}, + // require all the fields in the SSH certificate + &sshCertificateDefaultValidator{}, + ), nil +} + func getAndDecode(uri string, v interface{}) error { resp, err := http.Get(uri) if err != nil { diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 431bb7f8..9a756c5d 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -1,6 +1,8 @@ package provisioner import ( + "context" + "crypto" "crypto/x509" "fmt" "strings" @@ -276,7 +278,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.prov.AuthorizeSign(tt.args.token) + ctx := NewContextWithMethod(context.Background(), SignMethod) + got, err := tt.prov.AuthorizeSign(ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return @@ -286,7 +289,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) { } else { assert.NotNil(t, got) if tt.name == "admin" { - assert.Len(t, 4, got) + assert.Len(t, 3, got) } else { assert.Len(t, 5, got) } @@ -295,6 +298,117 @@ func TestOIDC_AuthorizeSign(t *testing.T) { } } +func TestOIDC_AuthorizeSign_SSH(t *testing.T) { + tm, fn := mockNow() + defer fn() + + srv := generateJWKServer(2) + defer srv.Close() + + var keys jose.JSONWebKeySet + assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys)) + + // Create test provisioners + p1, err := generateOIDC() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + p3, err := generateOIDC() + assert.FatalError(t, err) + // Admin + Domains + p3.Admins = []string{"name@smallstep.com", "root@example.com"} + p3.Domains = []string{"smallstep.com"} + + // Update configuration endpoints and initialize + config := Config{Claims: globalProvisionerClaims} + p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + assert.FatalError(t, p1.Init(config)) + assert.FatalError(t, p2.Init(config)) + assert.FatalError(t, p3.Init(config)) + + t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + // Admin email not in domains + okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + // Invalid email + failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + + key, err := generateJSONWebKey() + assert.FatalError(t, err) + + signer, err := generateJSONWebKey() + assert.FatalError(t, err) + + userDuration := p1.claimer.DefaultUserSSHCertDuration() + hostDuration := p1.claimer.DefaultHostSSHCertDuration() + expectedUserOptions := &SSHOptions{ + CertType: "user", Principals: []string{"name"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), + } + expectedAdminOptions := &SSHOptions{ + CertType: "user", Principals: []string{"root"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), + } + expectedHostOptions := &SSHOptions{ + CertType: "host", Principals: []string{"smallstep.com"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), + } + + type args struct { + token string + sshOpts SSHOptions + } + tests := []struct { + name string + prov *OIDC + args args + expected *SSHOptions + wantErr bool + wantSignErr bool + }{ + {"ok", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false}, + {"ok-user", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false}, + {"admin", p3, args{okAdmin, SSHOptions{}}, expectedAdminOptions, false, false}, + {"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}}, expectedAdminOptions, false, false}, + {"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}}, expectedAdminOptions, false, false}, + {"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false}, + {"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false}, + {"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}}, nil, false, true}, + {"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}}, nil, false, true}, + {"fail-email", p3, args{failEmail, SSHOptions{}}, nil, true, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContextWithMethod(context.Background(), SignSSHMethod) + got, err := tt.prov.AuthorizeSign(ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else if assert.NotNil(t, got) { + cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) + if (err != nil) != tt.wantSignErr { + t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) + } else { + if tt.wantSignErr { + assert.Nil(t, cert) + } else { + assert.NoError(t, validateSSHCertificate(cert, tt.expected)) + } + } + } + }) + } +} + func TestOIDC_AuthorizeRevoke(t *testing.T) { srv := generateJWKServer(2) defer srv.Close() diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 3118e6b0..248b93cf 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -1,9 +1,11 @@ package provisioner import ( + "context" "crypto/x509" "encoding/json" "net/url" + "regexp" "strings" "github.com/pkg/errors" @@ -17,7 +19,7 @@ type Interface interface { GetType() Type GetEncryptedKey() (kid string, key string, ok bool) Init(config Config) error - AuthorizeSign(token string) ([]SignOption, error) + AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) AuthorizeRenewal(cert *x509.Certificate) error AuthorizeRevoke(token string) error } @@ -169,3 +171,29 @@ func (l *List) UnmarshalJSON(data []byte) error { return nil } + +var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") + +// SanitizeSSHUserPrincipal grabs an email or a string with the format +// local@domain and returns a sanitized version of the local, valid to be used +// as a user name. If the email starts with a letter between a and z, the +// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. +func SanitizeSSHUserPrincipal(email string) string { + if i := strings.LastIndex(email, "@"); i >= 0 { + email = email[:i] + } + return strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= '0' && r <= '9': + return r + case r == '-': + return '-' + case r == '.': // drop dots + return -1 + default: + return '_' + } + }, strings.ToLower(email)) +} diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 11615e1a..d79c2b69 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -1,6 +1,8 @@ package provisioner -import "testing" +import ( + "testing" +) func TestType_String(t *testing.T) { tests := []struct { @@ -24,3 +26,29 @@ func TestType_String(t *testing.T) { }) } } + +func TestSanitizeSSHUserPrincipal(t *testing.T) { + type args struct { + email string + } + tests := []struct { + name string + args args + want string + }{ + {"simple", args{"foobar"}, "foobar"}, + {"camelcase", args{"FooBar"}, "foobar"}, + {"email", args{"foo@example.com"}, "foo"}, + {"email with dots", args{"foo.bar.zar@example.com"}, "foobarzar"}, + {"email with dashes", args{"foo-bar-zar@example.com"}, "foo-bar-zar"}, + {"email with underscores", args{"foo_bar_zar@example.com"}, "foo_bar_zar"}, + {"email with symbols", args{"Foo.Bar0123456789!#$%&'*+-/=?^_`{|}~;@example.com"}, "foobar0123456789________-___________"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := SanitizeSSHUserPrincipal(tt.args.email); got != tt.want { + t.Errorf("SanitizeSSHUserPrincipal() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go new file mode 100644 index 00000000..9ca2a95c --- /dev/null +++ b/authority/provisioner/sign_ssh_options.go @@ -0,0 +1,306 @@ +package provisioner + +import ( + "time" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +const ( + // SSHUserCert is the string used to represent ssh.UserCert. + SSHUserCert = "user" + + // SSHHostCert is the string used to represent ssh.HostCert. + SSHHostCert = "host" +) + +// SSHCertificateModifier is the interface used to change properties in an SSH +// certificate. +type SSHCertificateModifier interface { + SignOption + Modify(cert *ssh.Certificate) error +} + +// SSHCertificateOptionModifier is the interface used to add custom options used +// to modify the SSH certificate. +type SSHCertificateOptionModifier interface { + SignOption + Option(o SSHOptions) SSHCertificateModifier +} + +// SSHCertificateValidator is the interface used to validate an SSH certificate. +type SSHCertificateValidator interface { + SignOption + Valid(cert *ssh.Certificate) error +} + +// SSHCertificateOptionsValidator is the interface used to validate the custom +// options used to modify the SSH certificate. +type SSHCertificateOptionsValidator interface { + SignOption + Valid(got SSHOptions) error +} + +// SSHOptions contains the options that can be passed to the SignSSH method. +type SSHOptions struct { + CertType string `json:"certType"` + Principals []string `json:"principals"` + ValidAfter TimeDuration `json:"validAfter,omitempty"` + ValidBefore TimeDuration `json:"validBefore,omitempty"` +} + +// Type returns the uint32 representation of the CertType. +func (o SSHOptions) Type() uint32 { + return sshCertTypeUInt32(o.CertType) +} + +// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate. +func (o SSHOptions) Modify(cert *ssh.Certificate) error { + switch o.CertType { + case "": // ignore + case SSHUserCert: + cert.CertType = ssh.UserCert + case SSHHostCert: + cert.CertType = ssh.HostCert + default: + return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType) + } + cert.ValidPrincipals = o.Principals + if !o.ValidAfter.IsZero() { + cert.ValidAfter = uint64(o.ValidAfter.Time().Unix()) + } + if !o.ValidBefore.IsZero() { + cert.ValidBefore = uint64(o.ValidBefore.Time().Unix()) + } + if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { + return errors.New("ssh certificate valid after cannot be greater than valid before") + } + return nil +} + +// match compares two SSHOptions and return an error if they don't match. It +// ignores zero values. +func (o SSHOptions) match(got SSHOptions) error { + if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { + return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) + } + if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) { + return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) + } + if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { + return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) + } + if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { + return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) + } + return nil +} + +// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given +// Key ID in the SSH certificate. +type sshCertificateKeyIDModifier string + +func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error { + cert.KeyId = string(m) + return nil +} + +// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the +// certificate type to the SSH certificate. +type sshCertificateCertTypeModifier string + +func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error { + cert.CertType = sshCertTypeUInt32(string(m)) + return nil +} + +// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the +// principals to the SSH certificate. +type sshCertificatePrincipalsModifier []string + +func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error { + cert.ValidPrincipals = []string(m) + return nil +} + +// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the +// ValidAfter in the SSH certificate. +type sshCertificateValidAfterModifier uint64 + +func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error { + cert.ValidAfter = uint64(m) + return nil +} + +// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the +// ValidBefore in the SSH certificate. +type sshCertificateValidBeforeModifier uint64 + +func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error { + cert.ValidBefore = uint64(m) + return nil +} + +// sshCertificateDefaultModifier implements a SSHCertificateModifier that +// modifies the certificate with the given options if they are not set. +type sshCertificateDefaultsModifier SSHOptions + +// Modify implements the SSHCertificateModifier interface. +func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error { + if cert.CertType == 0 { + cert.CertType = sshCertTypeUInt32(m.CertType) + } + if len(cert.ValidPrincipals) == 0 { + cert.ValidPrincipals = m.Principals + } + if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() { + cert.ValidAfter = uint64(m.ValidAfter.Unix()) + } + if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() { + cert.ValidBefore = uint64(m.ValidBefore.Unix()) + } + return nil +} + +// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets +// the default extensions in an SSH certificate. +type sshDefaultExtensionModifier struct{} + +func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error { + switch cert.CertType { + // Default to no extensions for HostCert. + case ssh.HostCert: + return nil + case ssh.UserCert: + if cert.Extensions == nil { + cert.Extensions = make(map[string]string) + } + cert.Extensions["permit-X11-forwarding"] = "" + cert.Extensions["permit-agent-forwarding"] = "" + cert.Extensions["permit-port-forwarding"] = "" + cert.Extensions["permit-pty"] = "" + cert.Extensions["permit-user-rc"] = "" + return nil + default: + return errors.New("ssh certificate type has not been set or is invalid") + } +} + +// sshCertificateValidityModifier is a SSHCertificateModifier checks the +// validity bounds, setting them if they are not provided. It will fail if a +// CertType has not been set or is not valid. +type sshCertificateValidityModifier struct { + *Claimer +} + +func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error { + var d, min, max time.Duration + switch cert.CertType { + case ssh.UserCert: + d = m.DefaultUserSSHCertDuration() + min = m.MinUserSSHCertDuration() + max = m.MaxUserSSHCertDuration() + case ssh.HostCert: + d = m.DefaultHostSSHCertDuration() + min = m.MinHostSSHCertDuration() + max = m.MaxHostSSHCertDuration() + case 0: + return errors.New("ssh certificate type has not been set") + default: + return errors.Errorf("unknown ssh certificate type %d", cert.CertType) + } + + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now().Unix()) + } + if cert.ValidBefore == 0 { + t := time.Unix(int64(cert.ValidAfter), 0) + cert.ValidBefore = uint64(t.Add(d).Unix()) + } + + diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second + switch { + case diff < min: + return errors.Errorf("ssh certificate duration cannot be lower than %s", min) + case diff > max: + return errors.Errorf("ssh certificate duration cannot be greater than %s", max) + default: + return nil + } +} + +// sshCertificateOptionsValidator validates the user SSHOptions with the ones +// usually present in the token. +type sshCertificateOptionsValidator SSHOptions + +// Valid implements SSHCertificateOptionsValidator and returns nil if both +// SSHOptions match. +func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error { + want := SSHOptions(v) + return want.match(got) +} + +// sshCertificateDefaultValidator implements a simple validator for all the +// fields in the SSH certificate. +type sshCertificateDefaultValidator struct{} + +// Valid returns an error if the given certificate does not contain the necessary fields. +func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error { + switch { + case len(cert.Nonce) == 0: + return errors.New("ssh certificate nonce cannot be empty") + case cert.Key == nil: + return errors.New("ssh certificate key cannot be nil") + case cert.Serial == 0: + return errors.New("ssh certificate serial cannot be 0") + case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: + return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType) + case cert.KeyId == "": + return errors.New("ssh certificate key id cannot be empty") + case len(cert.ValidPrincipals) == 0: + return errors.New("ssh certificate valid principals cannot be empty") + case cert.ValidAfter == 0: + return errors.New("ssh certificate valid after cannot be 0") + case cert.ValidBefore == 0: + return errors.New("ssh certificate valid before cannot be 0") + case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0: + return errors.New("ssh certificate extensions cannot be empty") + case cert.SignatureKey == nil: + return errors.New("ssh certificate signature key cannot be nil") + case cert.Signature == nil: + return errors.New("ssh certificate signature cannot be nil") + default: + return nil + } +} + +// sshCertTypeUInt32 +func sshCertTypeUInt32(ct string) uint32 { + switch ct { + case SSHUserCert: + return ssh.UserCert + case SSHHostCert: + return ssh.HostCert + default: + return 0 + } +} + +// containsAllMembers reports whether all members of subgroup are within group. +func containsAllMembers(group, subgroup []string) bool { + lg, lsg := len(group), len(subgroup) + if lsg > lg || (lg > 0 && lsg == 0) { + return false + } + visit := make(map[string]struct{}, lg) + for i := 0; i < lg; i++ { + visit[group[i]] = struct{}{} + } + for i := 0; i < lsg; i++ { + if _, ok := visit[subgroup[i]]; !ok { + return false + } + } + return true +} diff --git a/authority/provisioner/ssh_test.go b/authority/provisioner/ssh_test.go new file mode 100644 index 00000000..1b31f78b --- /dev/null +++ b/authority/provisioner/ssh_test.go @@ -0,0 +1,125 @@ +package provisioner + +import ( + "crypto" + "crypto/rand" + "fmt" + "reflect" + "time" + + "golang.org/x/crypto/ssh" +) + +func validateSSHCertificate(cert *ssh.Certificate, opts *SSHOptions) error { + switch { + case cert == nil: + return fmt.Errorf("certificate is nil") + case cert.Signature == nil: + return fmt.Errorf("certificate signature is nil") + case cert.SignatureKey == nil: + return fmt.Errorf("certificate signature is nil") + case !reflect.DeepEqual(cert.ValidPrincipals, opts.Principals): + return fmt.Errorf("certificate principals are not equal, want %v, got %v", opts.Principals, cert.ValidPrincipals) + case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: + return fmt.Errorf("certificate type %v is not valid", cert.CertType) + case opts.CertType == "user" && cert.CertType != ssh.UserCert: + return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.UserCert, cert.CertType) + case opts.CertType == "host" && cert.CertType != ssh.HostCert: + return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType) + case cert.ValidAfter != uint64(opts.ValidAfter.Unix()): + return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0)) + case cert.ValidBefore != uint64(opts.ValidBefore.Unix()): + return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0)) + case opts.CertType == "user" && len(cert.Extensions) != 5: + return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions)) + case opts.CertType == "host" && len(cert.Extensions) != 0: + return fmt.Errorf("certificate extensions number is invalid, want 0, got %d", len(cert.Extensions)) + default: + return nil + } +} + +func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOption, signKey crypto.Signer) (*ssh.Certificate, error) { + pub, err := ssh.NewPublicKey(key) + if err != nil { + return nil, err + } + + var mods []SSHCertificateModifier + var validators []SSHCertificateValidator + + for _, op := range signOpts { + switch o := op.(type) { + // modify the ssh.Certificate + case SSHCertificateModifier: + mods = append(mods, o) + // modify the ssh.Certificate given the SSHOptions + case SSHCertificateOptionModifier: + mods = append(mods, o.Option(opts)) + // validate the ssh.Certificate + case SSHCertificateValidator: + validators = append(validators, o) + // validate the given SSHOptions + case SSHCertificateOptionsValidator: + if err := o.Valid(opts); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("signSSH: invalid extra option type %T", o) + } + } + + // Build base certificate with the key and some random values + cert := &ssh.Certificate{ + Nonce: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, + Key: pub, + Serial: 1234567890, + } + + // Use opts to modify the certificate + if err := opts.Modify(cert); err != nil { + return nil, err + } + + // Use provisioner modifiers + for _, m := range mods { + if err := m.Modify(cert); err != nil { + return nil, err + } + } + + // Get signer from authority keys + var signer ssh.Signer + switch cert.CertType { + case ssh.UserCert: + signer, err = ssh.NewSignerFromSigner(signKey) + case ssh.HostCert: + signer, err = ssh.NewSignerFromSigner(signKey) + default: + return nil, fmt.Errorf("unexpected ssh certificate type: %d", cert.CertType) + } + if err != nil { + return nil, err + } + cert.SignatureKey = signer.PublicKey() + + // Get bytes for signing trailing the signature length. + data := cert.Marshal() + data = data[:len(data)-4] + + // Sign the certificate + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return nil, err + } + cert.Signature = sig + + // User provisioners validators + for _, v := range validators { + if err := v.Valid(cert); err != nil { + return nil, err + } + } + + return cert, nil +} diff --git a/authority/provisioner/timeduration.go b/authority/provisioner/timeduration.go index fea967d5..7d197217 100644 --- a/authority/provisioner/timeduration.go +++ b/authority/provisioner/timeduration.go @@ -57,6 +57,17 @@ func (t *TimeDuration) SetTime(tt time.Time) { t.t, t.d = tt, 0 } +// IsZero returns true the TimeDuration represents the zero value, false +// otherwise. +func (t *TimeDuration) IsZero() bool { + return t.t.IsZero() && t.d == 0 +} + +// Equal returns if t and other are equal. +func (t *TimeDuration) Equal(other *TimeDuration) bool { + return t.t.Equal(other.t) && t.d == other.d +} + // MarshalJSON implements the json.Marshaler interface. If the time is set it // will return the time in RFC 3339 format if not it will return the duration // string. @@ -64,7 +75,7 @@ func (t TimeDuration) MarshalJSON() ([]byte, error) { switch { case t.t.IsZero(): if t.d == 0 { - return []byte("null"), nil + return []byte(`""`), nil } return json.Marshal(t.d.String()) default: @@ -102,11 +113,16 @@ func (t *TimeDuration) UnmarshalJSON(data []byte) error { return errors.Errorf("failed to parse %s", data) } -// Time calculates the embedded time.Time, sets it if necessary, and returns it. +// Time calculates the time if needed and returns it. func (t *TimeDuration) Time() time.Time { return t.RelativeTime(now()) } +// Unix calculates the time if needed it and returns the Unix time in seconds. +func (t *TimeDuration) Unix() int64 { + return t.RelativeTime(now()).Unix() +} + // RelativeTime returns the embedded time.Time or the base time plus the // duration if this is not zero. func (t *TimeDuration) RelativeTime(base time.Time) time.Time { diff --git a/authority/provisioner/timeduration_test.go b/authority/provisioner/timeduration_test.go index 97dd4ce5..65bd6f96 100644 --- a/authority/provisioner/timeduration_test.go +++ b/authority/provisioner/timeduration_test.go @@ -6,6 +6,17 @@ import ( "time" ) +func mockNow() (time.Time, func()) { + tm := time.Unix(1584198566, 535897000).UTC() + nowFn := now + now = func() time.Time { + return tm + } + return tm, func() { + now = nowFn + } +} + func TestNewTimeDuration(t *testing.T) { tm := time.Unix(1584198566, 535897000).UTC() type args struct { @@ -137,7 +148,7 @@ func TestTimeDuration_MarshalJSON(t *testing.T) { want []byte wantErr bool }{ - {"null", TimeDuration{}, []byte("null"), false}, + {"empty", TimeDuration{}, []byte(`""`), false}, {"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false}, {"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false}, {"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true}, @@ -166,7 +177,7 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) { want *TimeDuration wantErr bool }{ - {"null", args{[]byte("null")}, &TimeDuration{}, false}, + {"empty", args{[]byte(`""`)}, &TimeDuration{}, false}, {"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false}, {"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false}, {"fail", args{[]byte("123")}, &TimeDuration{}, true}, @@ -186,15 +197,8 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) { } func TestTimeDuration_Time(t *testing.T) { - nowFn := now - defer func() { - now = nowFn - now() - }() - tm := time.Unix(1584198566, 535897000).UTC() - now = func() time.Time { - return tm - } + tm, fn := mockNow() + defer fn() tests := []struct { name string timeDuration *TimeDuration @@ -211,6 +215,30 @@ func TestTimeDuration_Time(t *testing.T) { got := tt.timeDuration.Time() if !reflect.DeepEqual(got, tt.want) { t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTimeDuration_Unix(t *testing.T) { + tm, fn := mockNow() + defer fn() + tests := []struct { + name string + timeDuration *TimeDuration + want int64 + }{ + {"zero", nil, -62135596800}, + {"zero", &TimeDuration{}, -62135596800}, + {"timestamp", &TimeDuration{t: tm}, 1584198566}, + {"local", &TimeDuration{t: tm.Local()}, 1584198566}, + {"duration", &TimeDuration{d: 1 * time.Hour}, 1584202166}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.timeDuration.Unix() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TimeDuration.Unix() = %v, want %v", got, tt.want) } }) @@ -218,15 +246,8 @@ func TestTimeDuration_Time(t *testing.T) { } func TestTimeDuration_String(t *testing.T) { - nowFn := now - defer func() { - now = nowFn - now() - }() - tm := time.Unix(1584198566, 535897000).UTC() - now = func() time.Time { - return tm - } + tm, fn := mockNow() + defer fn() type fields struct { t time.Time d time.Duration diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 7871b75d..c6e820ed 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -480,6 +480,54 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { + return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{ + CertType: "user", + Principals: []string{"name"}, + }, jwk) +} + +func generateSimpleSSHHostToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { + return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{ + CertType: "host", + Principals: []string{"smallstep.com"}, + }, jwk) +} + +func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *SSHOptions, jwk *jose.JSONWebKey) (string, error) { + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + + claims := struct { + jose.Claims + Step *stepPayload `json:"step,omitempty"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + Step: &stepPayload{ + SSH: sshOpts, + }, + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, diff --git a/authority/provisioners.go b/authority/provisioners.go index 289b52a4..5328eb4d 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -13,7 +13,7 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) { key, ok := a.provisioners.LoadEncryptedKey(kid) if !ok { return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid), - http.StatusNotFound, context{}} + http.StatusNotFound, apiCtx{}} } return key, nil } @@ -31,7 +31,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi p, ok := a.provisioners.LoadByCertificate(crt) if !ok { return nil, &apiError{errors.Errorf("provisioner not found"), - http.StatusNotFound, context{}} + http.StatusNotFound, apiCtx{}} } return p, nil } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 303c4e8a..fb84a31d 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -35,7 +35,7 @@ func TestGetEncryptedKey(t *testing.T) { a: a, kid: "foo", err: &apiError{errors.Errorf("encrypted key with kid foo was not found"), - http.StatusNotFound, context{}}, + http.StatusNotFound, apiCtx{}}, } }, } diff --git a/authority/root.go b/authority/root.go index 51ed6ac5..3794a6c8 100644 --- a/authority/root.go +++ b/authority/root.go @@ -12,13 +12,13 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { val, ok := a.certificates.Load(sum) if !ok { return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum), - http.StatusNotFound, context{}} + http.StatusNotFound, apiCtx{}} } crt, ok := val.(*x509.Certificate) if !ok { return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), - http.StatusInternalServerError, context{}} + http.StatusInternalServerError, apiCtx{}} } return crt, nil } @@ -53,7 +53,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) if !ok { federation = nil err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), - http.StatusInternalServerError, context{}} + http.StatusInternalServerError, apiCtx{}} return false } federation = append(federation, crt) diff --git a/authority/root_test.go b/authority/root_test.go index d4caf71a..4b648d78 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -19,8 +19,8 @@ func TestRoot(t *testing.T) { sum string err *apiError }{ - "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}}, - "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}}, + "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}}, + "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, } diff --git a/authority/ssh.go b/authority/ssh.go new file mode 100644 index 00000000..2f69b3ca --- /dev/null +++ b/authority/ssh.go @@ -0,0 +1,239 @@ +package authority + +import ( + "crypto/rand" + "encoding/binary" + "net/http" + "strings" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/crypto/randutil" + "golang.org/x/crypto/ssh" +) + +const ( + // SSHAddUserPrincipal is the principal that will run the add user command. + // Defaults to "provisioner" but it can be changed in the configuration. + SSHAddUserPrincipal = "provisioner" + + // SSHAddUserCommand is the default command to run to add a new user. + // Defaults to "sudo useradd -m ; nc -q0 localhost 22" but it can be changed in the + // configuration. The string "" will be replace by the new + // principal to add. + SSHAddUserCommand = "sudo useradd -m ; nc -q0 localhost 22" +) + +// SignSSH creates a signed SSH certificate with the given public key and options. +func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + var mods []provisioner.SSHCertificateModifier + var validators []provisioner.SSHCertificateValidator + + for _, op := range signOpts { + switch o := op.(type) { + // modify the ssh.Certificate + case provisioner.SSHCertificateModifier: + mods = append(mods, o) + // modify the ssh.Certificate given the SSHOptions + case provisioner.SSHCertificateOptionModifier: + mods = append(mods, o.Option(opts)) + // validate the ssh.Certificate + case provisioner.SSHCertificateValidator: + validators = append(validators, o) + // validate the given SSHOptions + case provisioner.SSHCertificateOptionsValidator: + if err := o.Valid(opts); err != nil { + return nil, &apiError{err: err, code: http.StatusForbidden} + } + default: + return nil, &apiError{ + err: errors.Errorf("signSSH: invalid extra option type %T", o), + code: http.StatusInternalServerError, + } + } + } + + nonce, err := randutil.ASCII(32) + if err != nil { + return nil, &apiError{err: err, code: http.StatusInternalServerError} + } + + var serial uint64 + if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "signSSH: error reading random number"), + code: http.StatusInternalServerError, + } + } + + // Build base certificate with the key and some random values + cert := &ssh.Certificate{ + Nonce: []byte(nonce), + Key: key, + Serial: serial, + } + + // Use opts to modify the certificate + if err := opts.Modify(cert); err != nil { + return nil, &apiError{err: err, code: http.StatusForbidden} + } + + // Use provisioner modifiers + for _, m := range mods { + if err := m.Modify(cert); err != nil { + return nil, &apiError{err: err, code: http.StatusForbidden} + } + } + + // Get signer from authority keys + var signer ssh.Signer + switch cert.CertType { + case ssh.UserCert: + if a.sshCAUserCertSignKey == nil { + return nil, &apiError{ + err: errors.New("signSSH: user certificate signing is not enabled"), + code: http.StatusNotImplemented, + } + } + if signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey); err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "signSSH: error creating signer"), + code: http.StatusInternalServerError, + } + } + case ssh.HostCert: + if a.sshCAHostCertSignKey == nil { + return nil, &apiError{ + err: errors.New("signSSH: host certificate signing is not enabled"), + code: http.StatusNotImplemented, + } + } + if signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey); err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "signSSH: error creating signer"), + code: http.StatusInternalServerError, + } + } + default: + return nil, &apiError{ + err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType), + code: http.StatusInternalServerError, + } + } + cert.SignatureKey = signer.PublicKey() + + // Get bytes for signing trailing the signature length. + data := cert.Marshal() + data = data[:len(data)-4] + + // Sign the certificate + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "signSSH: error signing certificate"), + code: http.StatusInternalServerError, + } + } + cert.Signature = sig + + // User provisioners validators + for _, v := range validators { + if err := v.Valid(cert); err != nil { + return nil, &apiError{err: err, code: http.StatusForbidden} + } + } + + return cert, nil +} + +// SignSSHAddUser signs a certificate that provisions a new user in a server. +func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { + if a.sshCAUserCertSignKey == nil { + return nil, &apiError{ + err: errors.New("signSSHAddUser: user certificate signing is not enabled"), + code: http.StatusNotImplemented, + } + } + if subject.CertType != ssh.UserCert { + return nil, &apiError{ + err: errors.New("signSSHProxy: certificate is not a user certificate"), + code: http.StatusForbidden, + } + } + if len(subject.ValidPrincipals) != 1 { + return nil, &apiError{ + err: errors.New("signSSHProxy: certificate does not have only one principal"), + code: http.StatusForbidden, + } + } + + nonce, err := randutil.ASCII(32) + if err != nil { + return nil, &apiError{err: err, code: http.StatusInternalServerError} + } + + var serial uint64 + if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "signSSHProxy: error reading random number"), + code: http.StatusInternalServerError, + } + } + + signer, err := ssh.NewSignerFromSigner(a.sshCAUserCertSignKey) + if err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "signSSHProxy: error creating signer"), + code: http.StatusInternalServerError, + } + } + + principal := subject.ValidPrincipals[0] + addUserPrincipal := a.getAddUserPrincipal() + + cert := &ssh.Certificate{ + Nonce: []byte(nonce), + Key: key, + Serial: serial, + CertType: ssh.UserCert, + KeyId: principal + "-" + addUserPrincipal, + ValidPrincipals: []string{addUserPrincipal}, + ValidAfter: subject.ValidAfter, + ValidBefore: subject.ValidBefore, + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{ + "force-command": a.getAddUserCommand(principal), + }, + }, + SignatureKey: signer.PublicKey(), + } + + // Get bytes for signing trailing the signature length. + data := cert.Marshal() + data = data[:len(data)-4] + + // Sign the certificate + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return nil, err + } + cert.Signature = sig + return cert, nil +} + +func (a *Authority) getAddUserPrincipal() (cmd string) { + if a.config.SSH.AddUserPrincipal == "" { + return SSHAddUserPrincipal + } + return a.config.SSH.AddUserPrincipal +} + +func (a *Authority) getAddUserCommand(principal string) string { + var cmd string + if a.config.SSH.AddUserCommand == "" { + cmd = SSHAddUserCommand + } else { + cmd = a.config.SSH.AddUserCommand + } + return strings.Replace(cmd, "", principal, -1) +} diff --git a/authority/ssh_test.go b/authority/ssh_test.go new file mode 100644 index 00000000..37a9a8f7 --- /dev/null +++ b/authority/ssh_test.go @@ -0,0 +1,252 @@ +package authority + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "golang.org/x/crypto/ssh" +) + +type sshTestModifier ssh.Certificate + +func (m sshTestModifier) Modify(cert *ssh.Certificate) error { + if m.CertType != 0 { + cert.CertType = m.CertType + } + if m.KeyId != "" { + cert.KeyId = m.KeyId + } + if m.ValidAfter != 0 { + cert.ValidAfter = m.ValidAfter + } + if m.ValidBefore != 0 { + cert.ValidBefore = m.ValidBefore + } + if len(m.ValidPrincipals) != 0 { + cert.ValidPrincipals = m.ValidPrincipals + } + if m.Permissions.CriticalOptions != nil { + cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions + } + if m.Permissions.Extensions != nil { + cert.Permissions.Extensions = m.Permissions.Extensions + } + return nil +} + +type sshTestCertModifier string + +func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error { + if m == "" { + return nil + } + return fmt.Errorf(string(m)) +} + +type sshTestCertValidator string + +func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error { + if v == "" { + return nil + } + return fmt.Errorf(string(v)) +} + +type sshTestOptionsValidator string + +func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error { + if v == "" { + return nil + } + return fmt.Errorf(string(v)) +} + +type sshTestOptionsModifier string + +func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier { + return sshTestCertModifier(string(m)) +} + +func TestAuthority_SignSSH(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + pub, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + + userOptions := sshTestModifier{ + CertType: ssh.UserCert, + } + hostOptions := sshTestModifier{ + CertType: ssh.HostCert, + } + + now := time.Now() + + type fields struct { + sshCAUserCertSignKey crypto.Signer + sshCAHostCertSignKey crypto.Signer + } + type args struct { + key ssh.PublicKey + opts provisioner.SSHOptions + signOpts []provisioner.SignOption + } + type want struct { + CertType uint32 + Principals []string + ValidAfter uint64 + ValidBefore uint64 + } + tests := []struct { + name string + fields fields + args args + want want + wantErr bool + }{ + {"ok-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions}}, want{CertType: ssh.UserCert}, false}, + {"ok-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{hostOptions}}, want{CertType: ssh.HostCert}, false}, + {"ok-opts-type-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-type-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert}, false}, + {"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, + {"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, + {"ok-opts-valid-after", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, + {"ok-opts-valid-before", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, + {"ok-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, + {"fail-opts-type", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "foo"}, []provisioner.SignOption{}}, want{}, true}, + {"fail-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("an error")}}, want{}, true}, + {"fail-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("an error")}}, want{}, true}, + {"fail-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, + {"fail-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, + {"fail-bad-sign-options", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, "wrong type"}}, want{}, true}, + {"fail-no-user-key", fields{nil, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{}, true}, + {"fail-no-host-key", fields{signKey, nil}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{}, true}, + {"fail-bad-type", fields{signKey, nil}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{sshTestModifier{CertType: 0}}}, want{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey + a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey + + got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil && assert.NotNil(t, got) { + assert.Equals(t, tt.want.CertType, got.CertType) + assert.Equals(t, tt.want.Principals, got.ValidPrincipals) + assert.Equals(t, tt.want.ValidAfter, got.ValidAfter) + assert.Equals(t, tt.want.ValidBefore, got.ValidBefore) + assert.NotNil(t, got.Key) + assert.NotNil(t, got.Nonce) + assert.NotEquals(t, 0, got.Serial) + assert.NotNil(t, got.Signature) + assert.NotNil(t, got.SignatureKey) + } + }) + } +} + +func TestAuthority_SignSSHAddUser(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + pub, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + + type fields struct { + sshCAUserCertSignKey crypto.Signer + sshCAHostCertSignKey crypto.Signer + addUserPrincipal string + addUserCommand string + } + type args struct { + key ssh.PublicKey + subject *ssh.Certificate + } + type want struct { + CertType uint32 + Principals []string + ValidAfter uint64 + ValidBefore uint64 + ForceCommand string + } + + now := time.Now() + validCert := &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{"user"}, + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + } + validWant := want{ + CertType: ssh.UserCert, + Principals: []string{"provisioner"}, + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + ForceCommand: "sudo useradd -m user; nc -q0 localhost 22", + } + + tests := []struct { + name string + fields fields + args args + want want + wantErr bool + }{ + {"ok", fields{signKey, signKey, "", ""}, args{pub, validCert}, validWant, false}, + {"ok-no-host-key", fields{signKey, nil, "", ""}, args{pub, validCert}, validWant, false}, + {"ok-custom-principal", fields{signKey, signKey, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false}, + {"ok-custom-command", fields{signKey, signKey, "", "foo "}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false}, + {"ok-custom-principal-and-command", fields{signKey, signKey, "my-principal", "foo "}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false}, + {"fail-no-user-key", fields{nil, signKey, "", ""}, args{pub, validCert}, want{}, true}, + {"fail-no-user-cert", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true}, + {"fail-no-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true}, + {"fail-many-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey + a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey + a.config.SSH = &SSHConfig{ + AddUserPrincipal: tt.fields.addUserPrincipal, + AddUserCommand: tt.fields.addUserCommand, + } + got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil && assert.NotNil(t, got) { + assert.Equals(t, tt.want.CertType, got.CertType) + assert.Equals(t, tt.want.Principals, got.ValidPrincipals) + assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId) + assert.Equals(t, tt.want.ValidAfter, got.ValidAfter) + assert.Equals(t, tt.want.ValidBefore, got.ValidBefore) + assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions) + assert.Equals(t, nil, got.Extensions) + assert.NotNil(t, got.Key) + assert.NotNil(t, got.Nonce) + assert.NotEquals(t, 0, got.Serial) + assert.NotNil(t, got.Signature) + assert.NotNil(t, got.SignatureKey) + } + }) + } +} diff --git a/authority/tls.go b/authority/tls.go index fdaba130..d54e4373 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -58,7 +58,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption { // Sign creates a signed certificate from a certificate signing request. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) { var ( - errContext = context{"csr": csr, "signOptions": signOpts} + errContext = apiCtx{"csr": csr, "signOptions": signOpts} mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} certValidators = []provisioner.CertificateValidator{} issIdentity = a.intermediateIdentity @@ -181,23 +181,23 @@ func (a *Authority) Renew(oldCert *x509.Certificate) (*x509.Certificate, *x509.C leaf, err := x509util.NewLeafProfileWithTemplate(newCert, issIdentity.Crt, issIdentity.Key) if err != nil { - return nil, nil, &apiError{err, http.StatusInternalServerError, context{}} + return nil, nil, &apiError{err, http.StatusInternalServerError, apiCtx{}} } crtBytes, err := leaf.CreateCertificate() if err != nil { return nil, nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"), - http.StatusInternalServerError, context{}} + http.StatusInternalServerError, apiCtx{}} } serverCert, err := x509.ParseCertificate(crtBytes) if err != nil { return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"), - http.StatusInternalServerError, context{}} + http.StatusInternalServerError, apiCtx{}} } caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw) if err != nil { return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"), - http.StatusInternalServerError, context{}} + http.StatusInternalServerError, apiCtx{}} } return serverCert, caCert, nil @@ -222,7 +222,7 @@ type RevokeOptions struct { // // TODO: Add OCSP and CRL support. func (a *Authority) Revoke(opts *RevokeOptions) error { - errContext := context{ + errContext := apiCtx{ "serialNumber": opts.Serial, "reasonCode": opts.ReasonCode, "reason": opts.Reason, diff --git a/authority/tls_test.go b/authority/tls_test.go index 5e3a3746..8d443fd4 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto/rand" "crypto/sha1" "crypto/x509" @@ -103,7 +104,8 @@ func TestSign(t *testing.T) { assert.FatalError(t, err) token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key) assert.FatalError(t, err) - extraOpts, err := a.Authorize(token) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + extraOpts, err := a.Authorize(ctx, token) assert.FatalError(t, err) type signTest struct { @@ -124,7 +126,7 @@ func TestSign(t *testing.T) { signOpts: signOpts, err: &apiError{errors.New("sign: invalid certificate request"), http.StatusBadRequest, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -138,7 +140,7 @@ func TestSign(t *testing.T) { signOpts: signOpts, err: &apiError{errors.New("sign: invalid extra option type string"), http.StatusInternalServerError, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -153,7 +155,7 @@ func TestSign(t *testing.T) { signOpts: signOpts, err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"), http.StatusInternalServerError, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -168,7 +170,7 @@ func TestSign(t *testing.T) { signOpts: signOpts, err: &apiError{errors.New("sign: error creating new leaf certificate"), http.StatusInternalServerError, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -185,7 +187,7 @@ func TestSign(t *testing.T) { signOpts: _signOpts, err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), http.StatusUnauthorized, - context{"csr": csr, "signOptions": _signOpts}, + apiCtx{"csr": csr, "signOptions": _signOpts}, }, } }, @@ -200,7 +202,7 @@ func TestSign(t *testing.T) { signOpts: signOpts, err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), http.StatusUnauthorized, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -227,7 +229,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG signOpts: signOpts, err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"), http.StatusUnauthorized, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -238,7 +240,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG storeCertificate: func(crt *x509.Certificate) error { return &apiError{errors.New("force"), http.StatusInternalServerError, - context{"csr": csr, "signOptions": signOpts}} + apiCtx{"csr": csr, "signOptions": signOpts}} }, } return &signTest{ @@ -248,7 +250,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG signOpts: signOpts, err: &apiError{errors.New("sign: error storing certificate in db: force"), http.StatusInternalServerError, - context{"csr": csr, "signOptions": signOpts}, + apiCtx{"csr": csr, "signOptions": signOpts}, }, } }, @@ -401,7 +403,7 @@ func TestRenew(t *testing.T) { auth: _a, crt: crt, err: &apiError{errors.New("error renewing certificate from existing server certificate"), - http.StatusInternalServerError, context{}}, + http.StatusInternalServerError, apiCtx{}}, }, nil }, "fail-unauthorized": func() (*renewTest, error) { @@ -596,7 +598,7 @@ func TestRevoke(t *testing.T) { validAudience := []string{"https://test.ca.smallstep.com/revoke"} now := time.Now().UTC() getCtx := func() map[string]interface{} { - return context{ + return apiCtx{ "serialNumber": "sn", "reasonCode": reasonCode, "reason": reason, diff --git a/ca/client.go b/ca/client.go index 6ac16e8a..826bee7f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -373,6 +373,28 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { return &sign, nil } +// SignSSH performs the SSH certificate sign request to the CA and returns the +// api.SignSSHResponse struct. +func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"}) + resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var sign api.SignSSHResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &sign, nil +} + // Renew performs the renew request to the CA and returns the api.SignResponse // struct. func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { diff --git a/cmd/step-ca/main.go b/cmd/step-ca/main.go index fbb0f2b4..748d1b64 100644 --- a/cmd/step-ca/main.go +++ b/cmd/step-ca/main.go @@ -280,7 +280,11 @@ func stringifyFlag(f cli.Flag) string { usage := fv.FieldByName("Usage").String() placeholder := placeholderString.FindString(usage) if placeholder == "" { - placeholder = "" + switch f.(type) { + case cli.BoolFlag, cli.BoolTFlag: + default: + placeholder = "" + } } return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage }