Add /revoke API with interface db backend

pull/50/head
max furman 5 years ago
parent 07ff7d9807
commit ab4d569f36

27
Gopkg.lock generated

@ -228,12 +228,12 @@
version = "v1.2.0" version = "v1.2.0"
[[projects]] [[projects]]
digest = "1:40e195917a951a8bf867cd05de2a46aaf1806c50cf92eebf4c16f78cd196f747" digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b"
name = "github.com/pkg/errors" name = "github.com/pkg/errors"
packages = ["."] packages = ["."]
pruneopts = "UT" pruneopts = "UT"
revision = "645ef00459ed84a119197bfb8d8205042c6df63d" revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4"
version = "v0.8.0" version = "v0.8.1"
[[projects]] [[projects]]
digest = "1:2e76a73cb51f42d63a2a1a85b3dc5731fd4faf6821b434bd0ef2c099186031d6" digest = "1:2e76a73cb51f42d63a2a1a85b3dc5731fd4faf6821b434bd0ef2c099186031d6"
@ -300,6 +300,14 @@
pruneopts = "UT" pruneopts = "UT"
revision = "f851b6b63d8d5e78b8a986057034d69fe904c477" revision = "f851b6b63d8d5e78b8a986057034d69fe904c477"
[[projects]]
branch = "master"
digest = "1:fd8d9eb07509d8ef47fc82c99646f0b2203b2ba3c240ba77d8c457bb6109836d"
name = "github.com/smallstep/nosql"
packages = ["."]
pruneopts = "UT"
revision = "d8f68d14f9ae04e0991dce06b44768f2d38dccf8"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:ba52e5a5fb800ce55108b7a5f181bb809aab71c16736051312b0aa969f82ad39" digest = "1:ba52e5a5fb800ce55108b7a5f181bb809aab71c16736051312b0aa969f82ad39"
@ -316,15 +324,24 @@
pruneopts = "UT" pruneopts = "UT"
revision = "b67dcf995b6a7b7f14fad5fcb7cc5441b05e814b" revision = "b67dcf995b6a7b7f14fad5fcb7cc5441b05e814b"
[[projects]]
digest = "1:5f7414cf41466d4b4dd7ec52b2cd3e481e08cfd11e7e24fef730c0e483e88bb1"
name = "go.etcd.io/bbolt"
packages = ["."]
pruneopts = "UT"
revision = "63597a96ec0ad9e6d43c3fc81e809909e0237461"
version = "v1.3.2"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:a068d4e48e0f2e172903d25b6e066815fa8efd4b01102aec4c741f02a9650c03" digest = "1:5dd7da6df07f42194cb25d162b4b89664ed7b08d7d4334f6a288393d54b095ce"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
packages = [ packages = [
"cryptobyte", "cryptobyte",
"cryptobyte/asn1", "cryptobyte/asn1",
"ed25519", "ed25519",
"ed25519/internal/edwards25519", "ed25519/internal/edwards25519",
"ocsp",
"pbkdf2", "pbkdf2",
"ssh/terminal", "ssh/terminal",
] ]
@ -574,8 +591,10 @@
"github.com/smallstep/cli/token", "github.com/smallstep/cli/token",
"github.com/smallstep/cli/token/provision", "github.com/smallstep/cli/token/provision",
"github.com/smallstep/cli/usage", "github.com/smallstep/cli/usage",
"github.com/smallstep/nosql",
"github.com/tsenart/deadcode", "github.com/tsenart/deadcode",
"github.com/urfave/cli", "github.com/urfave/cli",
"golang.org/x/crypto/ocsp",
"golang.org/x/net/context", "golang.org/x/net/context",
"golang.org/x/net/http2", "golang.org/x/net/http2",
"google.golang.org/grpc", "google.golang.org/grpc",

@ -48,6 +48,10 @@ required = [
branch = "master" branch = "master"
name = "github.com/smallstep/cli" name = "github.com/smallstep/cli"
[[constraint]]
branch = "master"
name = "github.com/smallstep/nosql"
[prune] [prune]
go-tests = true go-tests = true
unused-packages = true unused-packages = true

@ -18,6 +18,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
@ -25,12 +26,17 @@ import (
// Authority is the interface implemented by a CA authority. // Authority is the interface implemented by a CA authority.
type Authority interface { type Authority interface {
// NOTE: Authorize will be deprecated in future releases. Please use the
// context specific Authoirize[Sign|Revoke|etc.] methods.
Authorize(ott string) ([]provisioner.SignOption, error) Authorize(ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
GetTLSOptions() *tlsutil.TLSOptions GetTLSOptions() *tlsutil.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
GetProvisioners(cursor string, limit int) (provisioner.List, string, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
Revoke(*authority.RevokeOptions) error
GetEncryptedKey(kid string) (string, error) GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error) GetRoots() (federation []*x509.Certificate, err error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
@ -236,6 +242,7 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("GET", "/root/{sha}", h.Root) r.MethodFunc("GET", "/root/{sha}", h.Root)
r.MethodFunc("POST", "/sign", h.Sign) r.MethodFunc("POST", "/sign", h.Sign)
r.MethodFunc("POST", "/renew", h.Renew) r.MethodFunc("POST", "/renew", h.Renew)
r.MethodFunc("POST", "/revoke", h.Revoke)
r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners", h.Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/roots", h.Roots)
@ -285,7 +292,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
NotAfter: body.NotAfter, NotAfter: body.NotAfter,
} }
signOpts, err := h.Authority.Authorize(body.OTT) signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil { if err != nil {
WriteError(w, Unauthorized(err)) WriteError(w, Unauthorized(err))
return return

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
@ -407,23 +408,110 @@ func TestSignRequest_Validate(t *testing.T) {
} }
} }
type mockProvisioner struct {
ret1, ret2, ret3 interface{}
err error
getID func() string
getTokenID func(string) (string, error)
getName func() string
getType func() provisioner.Type
getEncryptedKey func() (string, string, bool)
init func(provisioner.Config) error
authorizeRevoke func(ott string) error
authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorizeRenewal func(*x509.Certificate) error
}
func (m *mockProvisioner) GetID() string {
if m.getID != nil {
return m.getID()
}
return m.ret1.(string)
}
func (m *mockProvisioner) GetTokenID(token string) (string, error) {
if m.getTokenID != nil {
return m.getTokenID(token)
}
if m.ret1 == nil {
return "", m.err
}
return m.ret1.(string), m.err
}
func (m *mockProvisioner) GetName() string {
if m.getName != nil {
return m.getName()
}
return m.ret1.(string)
}
func (m *mockProvisioner) GetType() provisioner.Type {
if m.getType != nil {
return m.getType()
}
return m.ret1.(provisioner.Type)
}
func (m *mockProvisioner) GetEncryptedKey() (string, string, bool) {
if m.getEncryptedKey != nil {
return m.getEncryptedKey()
}
return m.ret1.(string), m.ret2.(string), m.ret3.(bool)
}
func (m *mockProvisioner) Init(c provisioner.Config) error {
if m.init != nil {
return m.init(c)
}
return m.err
}
func (m *mockProvisioner) AuthorizeRevoke(ott string) error {
if m.authorizeRevoke != nil {
return m.authorizeRevoke(ott)
}
return m.err
}
func (m *mockProvisioner) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
if m.authorizeSign != nil {
return m.authorizeSign(ott)
}
return m.ret1.([]provisioner.SignOption), m.err
}
func (m *mockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
if m.authorizeRenewal != nil {
return m.authorizeRenewal(c)
}
return m.err
}
type mockAuthority struct { type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
authorize func(ott string) ([]provisioner.SignOption, error) authorizeSign func(ott string) ([]provisioner.SignOption, error)
getTLSOptions func() *tlsutil.TLSOptions getTLSOptions func() *tlsutil.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
getEncryptedKey func(kid string) (string, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
getRoots func() ([]*x509.Certificate, error) revoke func(*authority.RevokeOptions) error
getFederation func() ([]*x509.Certificate, error) getEncryptedKey func(kid string) (string, error)
getRoots func() ([]*x509.Certificate, error)
getFederation func() ([]*x509.Certificate, error)
} }
// TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) { func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
if m.authorize != nil { return m.AuthorizeSign(ott)
return m.authorize(ott) }
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
if m.authorizeSign != nil {
return m.authorizeSign(ott)
} }
return m.ret1.([]provisioner.SignOption), m.err return m.ret1.([]provisioner.SignOption), m.err
} }
@ -463,6 +551,20 @@ func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provision
return m.ret1.(provisioner.List), m.ret2.(string), m.err return m.ret1.(provisioner.List), m.ret2.(string), m.err
} }
func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) {
if m.loadProvisionerByCertificate != nil {
return m.loadProvisionerByCertificate(cert)
}
return m.ret1.(provisioner.Interface), m.err
}
func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error {
if m.revoke != nil {
return m.revoke(opts)
}
return m.err
}
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
if m.getEncryptedKey != nil { if m.getEncryptedKey != nil {
return m.getEncryptedKey(kid) return m.getEncryptedKey(kid)
@ -617,7 +719,7 @@ func Test_caHandler_Sign(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr, ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorize: func(ott string) ([]provisioner.SignOption, error) { authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
return tt.certAttrOpts, tt.autherr return tt.certAttrOpts, tt.autherr
}, },
getTLSOptions: func() *tlsutil.TLSOptions { getTLSOptions: func() *tlsutil.TLSOptions {

@ -82,6 +82,11 @@ func InternalServerError(err error) error {
return NewError(http.StatusInternalServerError, err) return NewError(http.StatusInternalServerError, err)
} }
// NotImplemented returns a 500 error with the given error.
func NotImplemented(err error) error {
return NewError(http.StatusNotImplemented, err)
}
// BadRequest returns an 400 error with the given error. // BadRequest returns an 400 error with the given error.
func BadRequest(err error) error { func BadRequest(err error) error {
return NewError(http.StatusBadRequest, err) return NewError(http.StatusBadRequest, err)

@ -0,0 +1,105 @@
package api
import (
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
)
// RevokeResponse is the response object that returns the health of the server.
type RevokeResponse struct {
Status string `json:"status"`
}
// RevokeRequest is the request body for a revocation request.
type RevokeRequest struct {
Serial string `json:"serial"`
OTT string `json:"ott"`
ReasonCode int `json:"reasonCode"`
Reason string `json:"reason"`
Passive bool `json:"passive"`
}
// Validate checks the fields of the RevokeRequest and returns nil if they are ok
// or an error if something is wrong.
func (r *RevokeRequest) Validate() (err error) {
if r.Serial == "" {
return BadRequest(errors.New("missing serial"))
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return BadRequest(errors.New("reasonCode out of bounds"))
}
if !r.Passive {
return NotImplemented(errors.New("non-passive revocation not implemented"))
}
return
}
// Revoke supports handful of different methods that revoke a Certificate.
//
// NOTE: currently only Passive revocation is supported.
//
// TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
return
}
opts := &authority.RevokeOptions{
Serial: body.Serial,
Reason: body.Reason,
ReasonCode: body.ReasonCode,
PassiveOnly: body.Passive,
}
// A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS.
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
opts.OTT = body.OTT
} else {
// If no token is present, then the request must be made over mTLS and
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, BadRequest(errors.New("missing ott or peer certificate")))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
logCertificate(w, opts.Crt)
opts.MTLS = true
}
if err := h.Authority.Revoke(opts); err != nil {
WriteError(w, Forbidden(err))
return
}
logRevoke(w, opts)
w.WriteHeader(http.StatusOK)
JSON(w, &RevokeResponse{Status: "ok"})
}
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"serial": ri.Serial,
"reasonCode": ri.ReasonCode,
"reason": ri.Reason,
"passiveOnly": ri.PassiveOnly,
"mTLS": ri.MTLS,
})
}
}

@ -0,0 +1,234 @@
package api
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
)
func TestRevokeRequestValidate(t *testing.T) {
type test struct {
rr *RevokeRequest
err *Error
}
tests := map[string]test{
"error/missing serial": {
rr: &RevokeRequest{},
err: &Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
},
"error/bad reasonCode": {
rr: &RevokeRequest{
Serial: "sn",
ReasonCode: 15,
Passive: true,
},
err: &Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
},
"error/non-passive not implemented": {
rr: &RevokeRequest{
Serial: "sn",
ReasonCode: 8,
Passive: false,
},
err: &Error{Err: errors.New("non-passive revocation not implemented"), Status: http.StatusNotImplemented},
},
"ok": {
rr: &RevokeRequest{
Serial: "sn",
ReasonCode: 9,
Passive: true,
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
if err := tc.rr.Validate(); err != nil {
switch v := err.(type) {
case *Error:
assert.HasPrefix(t, v.Error(), tc.err.Error())
assert.Equals(t, v.StatusCode(), tc.err.Status)
default:
t.Errorf("unexpected error type: %T", v)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func Test_caHandler_Revoke(t *testing.T) {
type test struct {
input string
auth Authority
tls *tls.ConnectionState
err error
statusCode int
expected []byte
}
tests := map[string]func(*testing.T) test{
"400/json read error": func(t *testing.T) test {
return test{
input: "{",
statusCode: http.StatusBadRequest,
}
},
"400/invalid request body": func(t *testing.T) test {
input, err := json.Marshal(RevokeRequest{})
assert.FatalError(t, err)
return test{
input: string(input),
statusCode: http.StatusBadRequest,
}
},
"200/ott": func(t *testing.T) test {
input, err := json.Marshal(RevokeRequest{
Serial: "sn",
ReasonCode: 4,
Reason: "foo",
OTT: "valid",
Passive: true,
})
assert.FatalError(t, err)
return test{
input: string(input),
statusCode: http.StatusOK,
auth: &mockAuthority{
revoke: func(opts *authority.RevokeOptions) error {
assert.True(t, opts.PassiveOnly)
assert.False(t, opts.MTLS)
assert.Equals(t, opts.Serial, "sn")
assert.Equals(t, opts.ReasonCode, 4)
assert.Equals(t, opts.Reason, "foo")
return nil
},
},
expected: []byte(`{"status":"ok"}`),
}
},
"400/no OTT and no peer certificate": func(t *testing.T) test {
input, err := json.Marshal(RevokeRequest{
Serial: "sn",
ReasonCode: 4,
Passive: true,
})
assert.FatalError(t, err)
return test{
input: string(input),
statusCode: http.StatusBadRequest,
}
},
"200/no ott": func(t *testing.T) test {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
input, err := json.Marshal(RevokeRequest{
Serial: "1404354960355712309",
ReasonCode: 4,
Reason: "foo",
Passive: true,
})
assert.FatalError(t, err)
return test{
input: string(input),
statusCode: http.StatusOK,
tls: cs,
auth: &mockAuthority{
revoke: func(ri *authority.RevokeOptions) error {
assert.True(t, ri.PassiveOnly)
assert.True(t, ri.MTLS)
assert.Equals(t, ri.Serial, "1404354960355712309")
assert.Equals(t, ri.ReasonCode, 4)
assert.Equals(t, ri.Reason, "foo")
return nil
},
loadProvisionerByCertificate: func(crt *x509.Certificate) (provisioner.Interface, error) {
return &mockProvisioner{
getID: func() string {
return "mock-provisioner-id"
},
}, err
},
},
expected: []byte(`{"status":"ok"}`),
}
},
"500/ott authority.Revoke": func(t *testing.T) test {
input, err := json.Marshal(RevokeRequest{
Serial: "sn",
ReasonCode: 4,
Reason: "foo",
OTT: "valid",
Passive: true,
})
assert.FatalError(t, err)
return test{
input: string(input),
statusCode: http.StatusInternalServerError,
auth: &mockAuthority{
revoke: func(opts *authority.RevokeOptions) error {
return InternalServerError(errors.New("force"))
},
},
}
},
"403/ott authority.Revoke": func(t *testing.T) test {
input, err := json.Marshal(RevokeRequest{
Serial: "sn",
ReasonCode: 4,
Reason: "foo",
OTT: "valid",
Passive: true,
})
assert.FatalError(t, err)
return test{
input: string(input),
statusCode: http.StatusForbidden,
auth: &mockAuthority{
revoke: func(opts *authority.RevokeOptions) error {
return errors.New("force")
},
},
}
},
}
for name, _tc := range tests {
tc := _tc(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
if tc.tls != nil {
req.TLS = tc.tls
}
w := httptest.NewRecorder()
h.Revoke(logging.NewResponseLogger(w), req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if tc.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), tc.expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, tc.expected)
}
}
})
}
}

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
) )
@ -24,6 +25,7 @@ type Authority struct {
ottMap *sync.Map ottMap *sync.Map
startTime time.Time startTime time.Time
provisioners *provisioner.Collection provisioners *provisioner.Collection
db db.AuthDB
// Do not re-initialize // Do not re-initialize
initOnce bool initOnce bool
} }
@ -56,6 +58,12 @@ func (a *Authority) init() error {
var err error var err error
// Initialize step-ca Database if defined in configuration.
// If a.config.DB is nil then a noopDB will be returned.
if a.db, err = db.New(a.config.DB); err != nil {
return err
}
// Load the root certificates and add them to the certificate store // Load the root certificates and add them to the certificate store
a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root)) a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root))
for i, path := range a.config.Root { for i, path := range a.config.Root {
@ -111,3 +119,8 @@ func (a *Authority) init() error {
return nil return nil
} }
// Shutdown safely shuts down any clients, databases, etc. held by the Authority.
func (a *Authority) Shutdown() error {
return a.db.Shutdown()
}

@ -36,11 +36,19 @@ func testAuthority(t *testing.T) *Authority {
DisableRenewal: &disableRenewal, DisableRenewal: &disableRenewal,
}, },
}, },
&provisioner.JWK{
Name: "renew_disabled",
Type: "JWK",
Key: maxjwk,
Claims: &provisioner.Claims{
DisableRenewal: &disableRenewal,
},
},
} }
c := &Config{ c := &Config{
Address: "127.0.0.1:443", Address: "127.0.0.1:443",
Root: []string{"testdata/secrets/root_ca.crt"}, Root: []string{"testdata/certs/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateCert: "testdata/certs/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key", IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.ca.smallstep.com"}, DNSNames: []string{"test.ca.smallstep.com"},
Password: "pass", Password: "pass",

@ -24,15 +24,16 @@ type Claims struct {
Nonce string `json:"nonce,omitempty"` Nonce string `json:"nonce,omitempty"`
} }
// Authorize authorizes a signature request by validating and authenticating // authorizeToken parses the token and returns the provisioner used to generate
// a OTT that must be sent w/ the request. // the token. This method enforces the One-Time use policy (tokens can only be
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) { // used once).
func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
var errContext = map[string]interface{}{"ott": ott} var errContext = map[string]interface{}{"ott": ott}
// Validate payload // Validate payload
token, err := jose.ParseSigned(ott) token, err := jose.ParseSigned(ott)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrapf(err, "authorize: error parsing token"), return nil, &apiError{errors.Wrapf(err, "authorizeToken: error parsing token"),
http.StatusUnauthorized, errContext} http.StatusUnauthorized, errContext}
} }
@ -41,9 +42,10 @@ func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
// before we can look up the provisioner. // before we can look up the provisioner.
var claims Claims var claims Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, &apiError{err, http.StatusUnauthorized, errContext} return nil, &apiError{errors.Wrap(err, "authorizeToken"), http.StatusUnauthorized, errContext}
} }
// TODO: use new persistence layer abstraction.
// Do not accept tokens issued before the start of the ca. // Do not accept tokens issued before the start of the ca.
// This check is meant as a stopgap solution to the current lack of a persistence layer. // This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
@ -57,7 +59,7 @@ func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
p, ok := a.provisioners.LoadByToken(token, &claims.Claims) p, ok := a.provisioners.LoadByToken(token, &claims.Claims)
if !ok { if !ok {
return nil, &apiError{ return nil, &apiError{
errors.Errorf("authorize: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")), errors.Errorf("authorizeToken: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")),
http.StatusUnauthorized, errContext} http.StatusUnauthorized, errContext}
} }
@ -74,19 +76,69 @@ func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
UsedAt: time.Now().Unix(), UsedAt: time.Now().Unix(),
Subject: claims.Subject, Subject: claims.Subject,
}); ok { }); ok {
return nil, &apiError{errors.Errorf("authorize: token already used"), http.StatusUnauthorized, errContext} return nil, &apiError{errors.Errorf("authorizeToken: token already used"), http.StatusUnauthorized, errContext}
} }
} }
// Call the provisioner Authorize method to get the signing options return p, nil
opts, err := p.Authorize(ott) }
// Authorize is a passthrough to AuthorizeSign.
// NOTE: Authorize will be deprecated in a future release. Please use the
// context specific Authorize[Sign|Revoke|etc.] going forwards.
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
return a.AuthorizeSign(ott)
}
// AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
var errContext = context{"ott": ott}
p, err := a.authorizeToken(ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
}
// Call the provisioner AuthorizeSign method to apply provisioner specific
// auth claims and get the signing options.
opts, err := p.AuthorizeSign(ott)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "authorize"), http.StatusUnauthorized, errContext} return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
} }
return opts, nil return opts, nil
} }
// authorizeRevoke authorizes a revocation request by validating and authenticating
// the RevokeOptions POSTed with the request.
// Returns a tuple of the provisioner ID and error, if one occurred.
func (a *Authority) authorizeRevoke(opts *RevokeOptions) (p provisioner.Interface, err error) {
if opts.MTLS {
if opts.Crt.SerialNumber.String() != opts.Serial {
return nil, errors.New("authorizeRevoke: serial number in certificate different than body")
}
// Load the Certificate provisioner if one exists.
p, err = a.LoadProvisionerByCertificate(opts.Crt)
if err != nil {
return nil, errors.Wrap(err, "authorizeRevoke")
}
} else {
// Gets the token provisioner and validates common token fields.
p, err = a.authorizeToken(opts.OTT)
if err != nil {
return nil, errors.Wrap(err, "authorizeRevoke")
}
// Call the provisioner AuthorizeRevoke to apply provisioner specific auth claims.
err = p.AuthorizeRevoke(opts.OTT)
if err != nil {
return nil, errors.Wrap(err, "authorizeRevoke")
}
}
return
}
// authorizeRenewal tries to locate the step provisioner extension, and checks // authorizeRenewal tries to locate the step provisioner extension, and checks
// if for the configured provisioner, the renewal is enabled or not. If the // if for the configured provisioner, the renewal is enabled or not. If the
// extra extension cannot be found, authorize the renewal by default. // extra extension cannot be found, authorize the renewal by default.
@ -94,17 +146,35 @@ func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
// TODO(mariano): should we authorize by default? // TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenewal(crt *x509.Certificate) error { func (a *Authority) authorizeRenewal(crt *x509.Certificate) error {
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()} errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
// Check the passive revocation table.
isRevoked, err := a.db.IsRevoked(crt.SerialNumber.String())
if err != nil {
return &apiError{
err: errors.Wrap(err, "renew"),
code: http.StatusInternalServerError,
context: errContext,
}
}
if isRevoked {
return &apiError{
err: errors.New("renew: certificate has been revoked"),
code: http.StatusUnauthorized,
context: errContext,
}
}
p, ok := a.provisioners.LoadByCertificate(crt) p, ok := a.provisioners.LoadByCertificate(crt)
if !ok { if !ok {
return &apiError{ return &apiError{
err: errors.New("provisioner not found"), err: errors.New("renew: provisioner not found"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
context: errContext, context: errContext,
} }
} }
if err := p.AuthorizeRenewal(crt); err != nil { if err := p.AuthorizeRenewal(crt); err != nil {
return &apiError{ return &apiError{
err: err, err: errors.Wrap(err, "renew"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
context: errContext, context: errContext,
} }

@ -1,14 +1,17 @@
package authority package authority
import ( import (
"crypto/x509"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"gopkg.in/square/go-jose.v2/jwt"
) )
func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
@ -43,18 +46,20 @@ func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func TestAuthorize(t *testing.T) { func TestAuthority_authorizeToken(t *testing.T) {
a := testAuthority(t) a := testAuthority(t)
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err) assert.FatalError(t, err)
// Invalid keys
keyNoKid := &jose.JSONWebKey{Key: key.Key, KeyID: ""}
keyBadKid := &jose.JSONWebKey{Key: key.Key, KeyID: "foo"}
now := time.Now() sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli" validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/sign"} validAudience := []string{"https://test.ca.smallstep.com/revoke"}
type authorizeTest struct { type authorizeTest struct {
auth *Authority auth *Authority
@ -63,83 +68,292 @@ func TestAuthorize(t *testing.T) {
res []interface{} res []interface{}
} }
tests := map[string]func(t *testing.T) *authorizeTest{ tests := map[string]func(t *testing.T) *authorizeTest{
"fail invalid ott": func(t *testing.T) *authorizeTest { "fail/invalid-ott": func(t *testing.T) *authorizeTest {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: "foo", ott: "foo",
err: &apiError{errors.New("authorize: error parsing token"), err: &apiError{errors.New("authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}}, http.StatusUnauthorized, context{"ott": "foo"}},
} }
}, },
"fail empty key id": func(t *testing.T) *authorizeTest { "fail/prehistoric-token": func(t *testing.T) *authorizeTest {
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyNoKid) cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
IssuedAt: jwt.NewNumericDate(now.Add(-time.Hour)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"), err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, context{"ott": raw}},
} }
}, },
"fail provisioner not found": func(t *testing.T) *authorizeTest { "fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyBadKid) cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
assert.FatalError(t, err)
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"), err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, context{"ott": raw}},
} }
}, },
"fail invalid issuer": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
raw, err := generateToken("test.smallstep.com", "invalid-issuer", validAudience[0], nil, now, key) cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
http.StatusUnauthorized, context{"ott": raw}},
} }
}, },
"fail empty subject": func(t *testing.T) *authorizeTest { "fail/token-already-used": func(t *testing.T) *authorizeTest {
raw, err := generateToken("", validIssuer, validAudience[0], nil, now, key) _a := testAuthority(t)
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
_, err = _a.authorizeToken(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: _a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorize: token subject cannot be empty"), err: &apiError{errors.New("authorizeToken: token already used"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, context{"ott": raw}},
} }
}, },
"fail verify-sig-failure": func(t *testing.T) *authorizeTest { }
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
p, err := tc.auth.authorizeToken(tc.ott)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, p.GetID(), "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
}
}
})
}
}
func TestAuthority_authorizeRevoke(t *testing.T) {
a := testAuthority(t)
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
type authorizeTest struct {
auth *Authority
opts *RevokeOptions
err error
res []interface{}
}
tests := map[string]func(t *testing.T) *authorizeTest{
"fail/token/invalid-ott": func(t *testing.T) *authorizeTest {
return &authorizeTest{
auth: a,
opts: &RevokeOptions{OTT: "foo"},
err: errors.New("authorizeRevoke: authorizeToken: error parsing token"),
}
},
"fail/token/invalid-subject": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
opts: &RevokeOptions{OTT: raw},
err: errors.New("authorizeRevoke: token subject cannot be empty"),
}
},
"ok/token": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
opts: &RevokeOptions{OTT: raw},
}
},
"fail/mTLS/invalid-serial": func(t *testing.T) *authorizeTest {
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: raw + "00", opts: &RevokeOptions{MTLS: true, Crt: crt, Serial: "foo"},
err: &apiError{errors.New("authorize: error parsing claims: square/go-jose: error in cryptographic primitive"), err: errors.New("authorizeRevoke: serial number in certificate different than body"),
http.StatusUnauthorized, context{"ott": raw + "00"}},
} }
}, },
"fail token-already-used": func(t *testing.T) *authorizeTest { "fail/mTLS/load-provisioner": func(t *testing.T) *authorizeTest {
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key) crt, err := pemutil.ReadCertificate("./testdata/certs/provisioner-not-found.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
_, err = a.Authorize(raw) return &authorizeTest{
auth: a,
opts: &RevokeOptions{MTLS: true, Crt: crt, Serial: "41633491264736369593451462439668497527"},
err: errors.New("authorizeRevoke: provisioner not found"),
}
},
"ok/mTLS": func(t *testing.T) *authorizeTest {
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
opts: &RevokeOptions{MTLS: true, Crt: crt, Serial: "102012593071130646873265215610956555026"},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
p, err := tc.auth.authorizeRevoke(tc.opts)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, p) {
assert.Equals(t, p.GetID(), "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
}
}
}
})
}
}
func TestAuthority_AuthorizeSign(t *testing.T) {
a := testAuthority(t)
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/sign"}
type authorizeTest struct {
auth *Authority
ott string
err *apiError
res []interface{}
}
tests := map[string]func(t *testing.T) *authorizeTest{
"fail/invalid-ott": func(t *testing.T) *authorizeTest {
return &authorizeTest{
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
}
},
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: raw, ott: raw,
err: &apiError{errors.New("authorize: token already used"), err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}}, http.StatusUnauthorized, context{"ott": raw}},
} }
}, },
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key) cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
ott: raw, ott: raw,
res: []interface{}{"1", "2", "3", "4", "5", "6"},
} }
}, },
} }
@ -147,11 +361,104 @@ func TestAuthorize(t *testing.T) {
for name, genTestCase := range tests { for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
got, err := tc.auth.AuthorizeSign(tc.ott)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Nil(t, got)
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 6, got)
}
}
})
}
}
// TODO: remove once Authorize deprecated.
func TestAuthority_Authorize(t *testing.T) {
a := testAuthority(t)
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/sign"}
type authorizeTest struct {
auth *Authority
ott string
err *apiError
res []interface{}
}
tests := map[string]func(t *testing.T) *authorizeTest{
"fail/invalid-ott": func(t *testing.T) *authorizeTest {
return &authorizeTest{
auth: a,
ott: "foo",
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
http.StatusUnauthorized, context{"ott": "foo"}},
}
},
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}},
}
},
"ok": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
ott: raw,
}
},
}
crtOpts, err := tc.auth.Authorize(tc.ott) for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
got, err := tc.auth.Authorize(tc.ott)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.Nil(t, got)
switch v := err.(type) { switch v := err.(type) {
case *apiError: case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.HasPrefix(t, v.err.Error(), tc.err.Error())
@ -163,9 +470,120 @@ func TestAuthorize(t *testing.T) {
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, len(crtOpts), len(tc.res)) assert.Len(t, 6, got)
} }
} }
}) })
} }
} }
func TestAuthority_authorizeRenewal(t *testing.T) {
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
assert.FatalError(t, err)
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
assert.FatalError(t, err)
otherCrt, err := pemutil.ReadCertificate("testdata/certs/provisioner-not-found.crt")
assert.FatalError(t, err)
type authorizeTest struct {
auth *Authority
crt *x509.Certificate
err *apiError
}
tests := map[string]func(t *testing.T) *authorizeTest{
"fail/db.IsRevoked-error": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &MockAuthDB{
isRevoked: func(key string) (bool, error) {
return false, errors.New("force")
},
}
return &authorizeTest{
auth: a,
crt: fooCrt,
err: &apiError{errors.New("renew: force"),
http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}},
}
},
"fail/revoked": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &MockAuthDB{
isRevoked: func(key string) (bool, error) {
return true, nil
},
}
return &authorizeTest{
auth: a,
crt: fooCrt,
err: &apiError{errors.New("renew: certificate has been revoked"),
http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}},
}
},
"fail/load-provisioner": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &MockAuthDB{
isRevoked: func(key string) (bool, error) {
return false, nil
},
}
return &authorizeTest{
auth: a,
crt: otherCrt,
err: &apiError{errors.New("renew: provisioner not found"),
http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}},
}
},
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &MockAuthDB{
isRevoked: func(key string) (bool, error) {
return false, nil
},
}
return &authorizeTest{
auth: a,
crt: renewDisabledCrt,
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}},
}
},
"ok": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &MockAuthDB{
isRevoked: func(key string) (bool, error) {
return false, nil
},
}
return &authorizeTest{
auth: a,
crt: fooCrt,
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
err := tc.auth.authorizeRenewal(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
) )
@ -44,6 +45,7 @@ type Config struct {
Address string `json:"address"` Address string `json:"address"`
DNSNames []string `json:"dnsNames"` DNSNames []string `json:"dnsNames"`
Logger json.RawMessage `json:"logger,omitempty"` Logger json.RawMessage `json:"logger,omitempty"`
DB *db.Config `json:"db,omitempty"`
Monitoring json.RawMessage `json:"monitoring,omitempty"` Monitoring json.RawMessage `json:"monitoring,omitempty"`
AuthorityConfig *AuthConfig `json:"authority,omitempty"` AuthorityConfig *AuthConfig `json:"authority,omitempty"`
TLS *tlsutil.TLSOptions `json:"tls,omitempty"` TLS *tlsutil.TLSOptions `json:"tls,omitempty"`
@ -59,7 +61,7 @@ type AuthConfig struct {
} }
// Validate validates the authority configuration. // Validate validates the authority configuration.
func (c *AuthConfig) Validate(audiences []string) error { func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
if c == nil { if c == nil {
return errors.New("authority cannot be undefined") return errors.New("authority cannot be undefined")
} }
@ -168,10 +170,18 @@ func (c *Config) Validate() error {
// getAudiences returns the legacy and possible urls without the ports that will // getAudiences returns the legacy and possible urls without the ports that will
// be used as the default provisioner audiences. The CA might have proxies in // be used as the default provisioner audiences. The CA might have proxies in
// front so we cannot rely on the port. // front so we cannot rely on the port.
func (c *Config) getAudiences() []string { func (c *Config) getAudiences() provisioner.Audiences {
audiences := []string{legacyAuthority} audiences := provisioner.Audiences{
Sign: []string{legacyAuthority},
Revoke: []string{legacyAuthority},
}
for _, name := range c.DNSNames { for _, name := range c.DNSNames {
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name)) audiences.Sign = append(audiences.Sign,
fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
audiences.Revoke = append(audiences.Revoke,
fmt.Sprintf("https://%s/revoke", name), fmt.Sprintf("https://%s/1.0/revoke", name))
} }
return audiences return audiences
} }

@ -277,7 +277,7 @@ func TestAuthConfigValidate(t *testing.T) {
ac: &AuthConfig{ ac: &AuthConfig{
Provisioners: p, Provisioners: p,
Claims: &provisioner.Claims{ Claims: &provisioner.Claims{
MinTLSDur: &provisioner.Duration{-1}, MinTLSDur: &provisioner.Duration{Duration: -1},
}, },
}, },
err: errors.New("claims: MinTLSCertDuration must be greater than 0"), err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
@ -305,7 +305,7 @@ func TestAuthConfigValidate(t *testing.T) {
for name, get := range tests { for name, get := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := get(t) tc := get(t)
err := tc.ac.Validate([]string{}) err := tc.ac.Validate(provisioner.Audiences{})
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error()) assert.Equals(t, tc.err.Error(), err.Error())

@ -0,0 +1,55 @@
package authority
import (
"crypto/x509"
"github.com/smallstep/certificates/db"
)
type MockAuthDB struct {
err error
ret1, ret2 interface{}
init func(*db.Config) (db.AuthDB, error)
isRevoked func(string) (bool, error)
revoke func(rci *db.RevokedCertificateInfo) error
storeCertificate func(crt *x509.Certificate) error
shutdown func() error
}
func (m *MockAuthDB) Init(c *db.Config) (db.AuthDB, error) {
if m.init != nil {
return m.init(c)
}
if m.ret1 == nil {
return nil, m.err
}
return m.ret1.(*db.DB), m.err
}
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
if m.isRevoked != nil {
return m.isRevoked(sn)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) Revoke(rci *db.RevokedCertificateInfo) error {
if m.revoke != nil {
return m.revoke(rci)
}
return m.err
}
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.storeCertificate != nil {
return m.storeCertificate(crt)
}
return m.err
}
func (m *MockAuthDB) Shutdown() error {
if m.shutdown != nil {
return m.shutdown()
}
return m.err
}

@ -38,12 +38,12 @@ type Collection struct {
byID *sync.Map byID *sync.Map
byKey *sync.Map byKey *sync.Map
sorted provisionerSlice sorted provisionerSlice
audiences []string audiences Audiences
} }
// NewCollection initializes a collection of provisioners. The given list of // NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner. // audiences are the audiences used by the JWT provisioner.
func NewCollection(audiences []string) *Collection { func NewCollection(audiences Audiences) *Collection {
return &Collection{ return &Collection{
byID: new(sync.Map), byID: new(sync.Map),
byKey: new(sync.Map), byKey: new(sync.Map),
@ -59,7 +59,7 @@ func (c *Collection) Load(id string) (Interface, bool) {
// LoadByToken parses the token claims and loads the provisioner associated. // LoadByToken parses the token claims and loads the provisioner associated.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
// match with server audiences // match with server audiences
if matchesAudience(claims.Audience, c.audiences) { if matchesAudience(claims.Audience, c.audiences.All()) {
// If matches with stored audiences it will be a JWT token (default), and // If matches with stored audiences it will be a JWT token (default), and
// the id would be <issuer>:<kid>. // the id would be <issuer>:<kid>.
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)

@ -68,14 +68,14 @@ func TestCollection_LoadByToken(t *testing.T) {
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk) token, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, c1, err := parseToken(token) t1, c1, err := parseToken(token)
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err = decryptJSONWebKey(p2.EncryptedKey) jwk, err = decryptJSONWebKey(p2.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk) token, err = generateSimpleToken(p2.Name, testAudiences.Sign[1], jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
t2, c2, err := parseToken(token) t2, c2, err := parseToken(token)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -92,7 +92,7 @@ func TestCollection_LoadByToken(t *testing.T) {
type fields struct { type fields struct {
byID *sync.Map byID *sync.Map
audiences []string audiences Audiences
} }
type args struct { type args struct {
token *jose.JSONWebToken token *jose.JSONWebToken
@ -109,7 +109,7 @@ func TestCollection_LoadByToken(t *testing.T) {
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true}, {"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true}, {"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false}, {"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false}, {"fail", fields{byID, Audiences{Sign: []string{"https://foo"}}}, args{t1, c1}, nil, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -162,7 +162,7 @@ func TestCollection_LoadByCertificate(t *testing.T) {
type fields struct { type fields struct {
byID *sync.Map byID *sync.Map
audiences []string audiences Audiences
} }
type args struct { type args struct {
cert *x509.Certificate cert *x509.Certificate

@ -24,7 +24,7 @@ type JWK struct {
EncryptedKey string `json:"encryptedKey,omitempty"` EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer claimer *Claimer
audiences []string audiences Audiences
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -33,6 +33,25 @@ func (p *JWK) GetID() string {
return p.Name + ":" + p.Key.KeyID return p.Name + ":" + p.Key.KeyID
} }
//
// GetTokenID returns the identifier of the token.
func (p *JWK) GetTokenID(ott string) (string, error) {
// Validate payload
token, err := jose.ParseSigned(ott)
if err != nil {
return "", errors.Wrap(err, "error parsing token")
}
// Get claims w/out verification. We need to look up the provisioner
// key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner.
var claims jose.Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return "", errors.Wrap(err, "error verifying claims")
}
return claims.ID, nil
}
// GetName returns the name of the provisioner. // GetName returns the name of the provisioner.
func (p *JWK) GetName() string { func (p *JWK) GetName() string {
return p.Name return p.Name
@ -68,8 +87,10 @@ func (p *JWK) Init(config Config) (err error) {
return err return err
} }
// Authorize validates the given token. // authorizeToken performs common jwt authorization actions and returns the
func (p *JWK) Authorize(token string) ([]SignOption, error) { // claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request.
func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errors.Wrapf(err, "error parsing token")
@ -90,7 +111,7 @@ func (p *JWK) Authorize(token string) ([]SignOption, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errors.New("invalid token: invalid audience claim (aud)") return nil, errors.New("invalid token: invalid audience claim (aud)")
} }
@ -98,6 +119,22 @@ func (p *JWK) Authorize(token string) ([]SignOption, error) {
return nil, errors.New("token subject cannot be empty") return nil, errors.New("token subject cannot be empty")
} }
return &claims, nil
}
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
return err
}
// AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil {
return nil, err
}
// NOTE: This is for backwards compatibility with older versions of cli // NOTE: This is for backwards compatibility with older versions of cli
// and certificates. Older versions added the token subject as the only SAN // and certificates. Older versions added the token subject as the only SAN
// in a CSR by default. // in a CSR by default.
@ -123,9 +160,3 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
} }
return nil return nil
} }
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(token string) error {
return errors.New("not implemented")
}

@ -110,7 +110,7 @@ func TestJWK_Init(t *testing.T) {
} }
} }
func TestJWK_Authorize(t *testing.T) { func TestJWK_authorizeToken(t *testing.T) {
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateJWK() p2, err := generateJWK()
@ -121,11 +121,11 @@ func TestJWK_Authorize(t *testing.T) {
key2, err := decryptJSONWebKey(p2.EncryptedKey) key2, err := decryptJSONWebKey(p2.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1) t1, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], key1)
assert.FatalError(t, err) assert.FatalError(t, err)
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2) t2, err := generateSimpleToken(p2.Name, testAudiences.Sign[1], key2)
assert.FatalError(t, err) assert.FatalError(t, err)
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1) t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences.Sign[0], "", []string{}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// Invalid tokens // Invalid tokens
@ -133,14 +133,14 @@ func TestJWK_Authorize(t *testing.T) {
key3, err := generateJSONWebKey() key3, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
// missing key // missing key
failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3) failKey, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], key3)
assert.FatalError(t, err) assert.FatalError(t, err)
// invalid token // invalid token
failTok := "foo." + parts[1] + "." + parts[2] failTok := "foo." + parts[1] + "." + parts[2]
// invalid claims // invalid claims
failClaims := parts[0] + ".foo." + parts[1] failClaims := parts[0] + ".foo." + parts[1]
// invalid issuer // invalid issuer
failIss, err := generateSimpleToken("foobar", testAudiences[0], key1) failIss, err := generateSimpleToken("foobar", testAudiences.Sign[0], key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// invalid audience // invalid audience
failAud, err := generateSimpleToken(p1.Name, "foobar", key1) failAud, err := generateSimpleToken(p1.Name, "foobar", key1)
@ -148,13 +148,13 @@ func TestJWK_Authorize(t *testing.T) {
// invalid signature // invalid signature
failSig := t1[0 : len(t1)-2] failSig := t1[0 : len(t1)-2]
// no subject // no subject
failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1) failSub, err := generateToken("", p1.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// expired // expired
failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1) failExp, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// not before // not before
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1) failNbf, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// Remove encrypted key for p2 // Remove encrypted key for p2
@ -164,97 +164,156 @@ func TestJWK_Authorize(t *testing.T) {
token string token string
} }
tests := []struct { tests := []struct {
name string name string
prov *JWK prov *JWK
args args args args
wantErr bool err error
}{ }{
{"ok", p1, args{t1}, false}, {"fail-token", p1, args{failTok}, errors.New("error parsing token")},
{"ok-no-encrypted-key", p2, args{t2}, false}, {"fail-key", p1, args{failKey}, errors.New("error parsing claims")},
{"ok-no-sans", p1, args{t3}, false}, {"fail-claims", p1, args{failClaims}, errors.New("error parsing claims")},
{"fail-key", p1, args{failKey}, true}, {"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
{"fail-token", p1, args{failTok}, true}, {"fail-issuer", p1, args{failIss}, errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")},
{"fail-claims", p1, args{failClaims}, true}, {"fail-expired", p1, args{failExp}, errors.New("invalid token: square/go-jose/jwt: validation failed, token is expired (exp)")},
{"fail-issuer", p1, args{failIss}, true}, {"fail-not-before", p1, args{failNbf}, errors.New("invalid token: square/go-jose/jwt: validation failed, token not valid yet (nbf)")},
{"fail-audience", p1, args{failAud}, true}, {"fail-audience", p1, args{failAud}, errors.New("invalid token: invalid audience claim (aud)")},
{"fail-signature", p1, args{failSig}, true}, {"fail-subject", p1, args{failSub}, errors.New("token subject cannot be empty")},
{"fail-subject", p1, args{failSub}, true}, {"ok", p1, args{t1}, nil},
{"fail-expired", p1, args{failExp}, true}, {"ok-no-encrypted-key", p2, args{t2}, nil},
{"fail-not-before", p1, args{failNbf}, true}, {"ok-no-sans", p1, args{t3}, nil},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.Authorize(tt.args.token) if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
if (err != nil) != tt.wantErr { if assert.NotNil(t, tt.err) {
t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr) assert.HasPrefix(t, err.Error(), tt.err.Error())
return }
}
if err != nil {
assert.Nil(t, got)
} else { } else {
assert.Nil(t, tt.err)
assert.NotNil(t, got) assert.NotNil(t, got)
assert.Len(t, 6, got)
} }
}) })
} }
} }
func TestJWK_AuthorizeRenewal(t *testing.T) { func TestJWK_AuthorizeRevoke(t *testing.T) {
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateJWK() key1, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateSimpleToken(p1.Name, testAudiences.Revoke[0], key1)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
type args struct { type args struct {
cert *x509.Certificate token string
} }
tests := []struct { tests := []struct {
name string name string
prov *JWK prov *JWK
args args args args
wantErr bool err error
}{ }{
{"ok", p1, args{nil}, false}, {"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
{"fail", p2, args{nil}, true}, {"ok", p1, args{t1}, nil},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRevoke(tt.args.token); err != nil {
t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} }
}) })
} }
} }
func TestJWK_AuthorizeRevoke(t *testing.T) { func TestJWK_AuthorizeSign(t *testing.T) {
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
key1, err := decryptJSONWebKey(p1.EncryptedKey) key1, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
t1, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], key1)
assert.FatalError(t, err)
t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
type args struct { type args struct {
token string token string
} }
tests := []struct {
name string
prov *JWK
args args
err error
}{
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
{"ok-sans", p1, args{t1}, nil},
{"ok-no-sans", p1, args{t2}, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.AuthorizeSign(tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 6, got)
_cnv := got[0]
cnv, ok := _cnv.(commonNameValidator)
assert.True(t, ok)
assert.Equals(t, string(cnv), "subject")
_dnv := got[1]
dnv, ok := _dnv.(dnsNamesValidator)
assert.True(t, ok)
if tt.name == "ok-sans" {
assert.Equals(t, []string(dnv), []string{"test.smallstep.com"})
} else {
assert.Equals(t, []string(dnv), []string{"subject"})
}
}
}
})
}
}
func TestJWK_AuthorizeRenewal(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct { tests := []struct {
name string name string
prov *JWK prov *JWK
args args args args
wantErr bool wantErr bool
}{ }{
{"disabled", p1, args{t1}, true}, {"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }

@ -9,6 +9,10 @@ func (p *noop) GetID() string {
return "noop" return "noop"
} }
func (p *noop) GetTokenID(token string) (string, error) {
return "", nil
}
func (p *noop) GetName() string { func (p *noop) GetName() string {
return "noop" return "noop"
} }
@ -24,7 +28,7 @@ func (p *noop) Init(config Config) error {
return nil return nil
} }
func (p *noop) Authorize(token string) ([]SignOption, error) { func (p *noop) AuthorizeSign(token string) ([]SignOption, error) {
return []SignOption{}, nil return []SignOption{}, nil
} }

@ -21,7 +21,7 @@ func Test_noop(t *testing.T) {
assert.Equals(t, "", key) assert.Equals(t, "", key)
assert.Equals(t, false, ok) assert.Equals(t, false, ok)
sigOptions, err := p.Authorize("foo") sigOptions, err := p.AuthorizeSign("foo")
assert.Equals(t, []SignOption{}, sigOptions) assert.Equals(t, []SignOption{}, sigOptions)
assert.Equals(t, nil, err) assert.Equals(t, nil, err)
} }

@ -83,6 +83,25 @@ func (o *OIDC) GetID() string {
return o.ClientID return o.ClientID
} }
// GetTokenID returns the provisioner unique identifier, the OIDC provisioner the
// uses the clientID for this.
func (o *OIDC) GetTokenID(ott string) (string, error) {
// Validate payload
token, err := jose.ParseSigned(ott)
if err != nil {
return "", errors.Wrap(err, "error parsing token")
}
// Get claims w/out verification. We need to look up the provisioner
// key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner.
var claims openIDPayload
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return "", errors.Wrap(err, "error verifying claims")
}
return claims.Nonce, nil
}
// GetName returns the name of the provisioner. // GetName returns the name of the provisioner.
func (o *OIDC) GetName() string { func (o *OIDC) GetName() string {
return o.Name return o.Name
@ -171,8 +190,9 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
return nil return nil
} }
// Authorize validates the given token. // authorizeToken applies the most common provisioner authorization claims,
func (o *OIDC) Authorize(token string) ([]SignOption, error) { // leaving the rest to context specific methods.
func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errors.Wrapf(err, "error parsing token")
@ -201,6 +221,31 @@ func (o *OIDC) Authorize(token string) ([]SignOption, error) {
return nil, err return nil, err
} }
return &claims, nil
}
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
// Only tokens generated by an admin have the right to revoke a certificate.
func (o *OIDC) AuthorizeRevoke(token string) error {
claims, err := o.authorizeToken(token)
if err != nil {
return err
}
// Only admins can revoke certificates.
if o.IsAdmin(claims.Email) {
return nil
}
return errors.New("cannot revoke with non-admin token")
}
// AuthorizeSign validates the given token.
func (o *OIDC) AuthorizeSign(token string) ([]SignOption, error) {
claims, err := o.authorizeToken(token)
if err != nil {
return nil, err
}
// Admins should be able to authorize any SAN // Admins should be able to authorize any SAN
if o.IsAdmin(claims.Email) { if o.IsAdmin(claims.Email) {
return []SignOption{ return []SignOption{
@ -226,12 +271,6 @@ func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
return nil return nil
} }
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (o *OIDC) AuthorizeRevoke(token string) error {
return errors.New("not implemented")
}
func getAndDecode(uri string, v interface{}) error { func getAndDecode(uri string, v interface{}) error {
resp, err := http.Get(uri) resp, err := http.Get(uri)
if err != nil { if err != nil {

@ -122,7 +122,7 @@ func TestOIDC_Init(t *testing.T) {
} }
} }
func TestOIDC_Authorize(t *testing.T) { func TestOIDC_authorizeToken(t *testing.T) {
srv := generateJWKServer(2) srv := generateJWKServer(2)
defer srv.Close() defer srv.Close()
@ -153,12 +153,6 @@ func TestOIDC_Authorize(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1]) t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1])
assert.FatalError(t, err) assert.FatalError(t, err)
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email // Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
@ -202,8 +196,6 @@ func TestOIDC_Authorize(t *testing.T) {
}{ }{
{"ok1", p1, args{t1}, false}, {"ok1", p1, args{t1}, false},
{"ok2", p2, args{t2}, false}, {"ok2", p2, args{t2}, false},
{"admin", p3, args{t3}, false},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true}, {"fail-email", p3, args{failEmail}, true},
{"fail-domain", p3, args{failDomain}, true}, {"fail-domain", p3, args{failDomain}, true},
{"fail-key", p1, args{failKey}, true}, {"fail-key", p1, args{failKey}, true},
@ -217,7 +209,74 @@ func TestOIDC_Authorize(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.Authorize(tt.args.token) got, err := tt.prov.authorizeToken(tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else {
assert.NotNil(t, got)
assert.Equals(t, got.Issuer, "the-issuer")
}
})
}
}
func TestOIDC_AuthorizeSign(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p2.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok1", p1, args{t1}, false},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.AuthorizeSign(tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
fmt.Println(tt) fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
@ -237,6 +296,63 @@ func TestOIDC_Authorize(t *testing.T) {
} }
} }
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok1", p1, args{t1}, true},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRevoke(tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestOIDC_AuthorizeRenewal(t *testing.T) { func TestOIDC_AuthorizeRenewal(t *testing.T) {
p1, err := generateOIDC() p1, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -270,6 +386,7 @@ func TestOIDC_AuthorizeRenewal(t *testing.T) {
} }
} }
/*
func TestOIDC_AuthorizeRevoke(t *testing.T) { func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2) srv := generateJWKServer(2)
defer srv.Close() defer srv.Close()
@ -308,6 +425,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
}) })
} }
} }
*/
func Test_sanitizeEmail(t *testing.T) { func Test_sanitizeEmail(t *testing.T) {
tests := []struct { tests := []struct {

@ -11,15 +11,27 @@ import (
// Interface is the interface that all provisioner types must implement. // Interface is the interface that all provisioner types must implement.
type Interface interface { type Interface interface {
GetID() string GetID() string
GetTokenID(token string) (string, error)
GetName() string GetName() string
GetType() Type GetType() Type
GetEncryptedKey() (kid string, key string, ok bool) GetEncryptedKey() (kid string, key string, ok bool)
Init(config Config) error Init(config Config) error
Authorize(token string) ([]SignOption, error) AuthorizeSign(token string) ([]SignOption, error)
AuthorizeRenewal(cert *x509.Certificate) error AuthorizeRenewal(cert *x509.Certificate) error
AuthorizeRevoke(token string) error AuthorizeRevoke(token string) error
} }
// Audiences stores all supported audiences by request type.
type Audiences struct {
Sign []string
Revoke []string
}
// All returns all supported audiences across all request types in one list.
func (a *Audiences) All() []string {
return append(a.Sign, a.Revoke...)
}
// Type indicates the provisioner Type. // Type indicates the provisioner Type.
type Type int type Type int
@ -31,6 +43,11 @@ const (
// TypeOIDC is used to indicate the OIDC provisioners. // TypeOIDC is used to indicate the OIDC provisioners.
TypeOIDC Type = 2 TypeOIDC Type = 2
// RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map.
RevokeAudienceKey = "revoke"
// SignAudienceKey is the key for the 'sign' audiences in the audiences map.
SignAudienceKey = "sign"
) )
// Config defines the default parameters used in the initialization of // Config defines the default parameters used in the initialization of
@ -39,7 +56,7 @@ type Config struct {
// Claims are the default claims. // Claims are the default claims.
Claims Claims Claims Claims
// Audiences are the audiences used in the default provisioner, (JWK). // Audiences are the audiences used in the default provisioner, (JWK).
Audiences []string Audiences Audiences
} }
type provisioner struct { type provisioner struct {

@ -12,9 +12,9 @@ import (
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
var testAudiences = []string{ var testAudiences = Audiences{
"https://ca.smallstep.com/sign", Sign: []string{"https://ca.smallstep.com/sign", "https://ca.smallstep.com/1.0/sign"},
"https://ca.smallsteomcom/1.0/sign", Revoke: []string{"https://ca.smallstep.com/revoke", "https://ca.smallstep.com/1.0/revoke"},
} }
func must(args ...interface{}) []interface{} { func must(args ...interface{}) []interface{} {

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"crypto/x509"
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -23,3 +24,14 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
provisioners, nextCursor := a.provisioners.Find(cursor, limit) provisioners, nextCursor := a.provisioners.Find(cursor, limit)
return provisioners, nextCursor, nil return provisioners, nextCursor, nil
} }
// LoadProvisionerByCertificate returns an interface to the provisioner that
// provisioned the certificate.
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, context{}}
}
return p, nil
}

@ -48,7 +48,7 @@ func TestRoot(t *testing.T) {
} }
func TestAuthority_GetRootCertificate(t *testing.T) { func TestAuthority_GetRootCertificate(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -70,7 +70,7 @@ func TestAuthority_GetRootCertificate(t *testing.T) {
} }
func TestAuthority_GetRootCertificates(t *testing.T) { func TestAuthority_GetRootCertificates(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -92,7 +92,7 @@ func TestAuthority_GetRootCertificates(t *testing.T) {
} }
func TestAuthority_GetRoots(t *testing.T) { func TestAuthority_GetRoots(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -120,7 +120,7 @@ func TestAuthority_GetRoots(t *testing.T) {
} }
func TestAuthority_GetFederation(t *testing.T) { func TestAuthority_GetFederation(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -0,0 +1,14 @@
-----BEGIN CERTIFICATE-----
MIICIDCCAcagAwIBAgIQTL7pKDl8mFzRziotXbgjEjAKBggqhkjOPQQDAjAnMSUw
IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyMjIy
MjkyOVoXDTE5MDMyMzIyMjkyOVowHDEaMBgGA1UEAxMRZm9vLnNtYWxsc3RlcC5j
b20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQbptfDonFaeUPiTr52wl9r3dcz
greolwDRmsgyFgnr1EuKH56WRcgH1gjfL0pybFlO3PdgBukR4u+sveq343OAo4He
MIHbMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH
AwIwHQYDVR0OBBYEFP9pHiVlsx5mr4L2QirOb1G9Mo4jMB8GA1UdIwQYMBaAFKEe
9IdMyaHdURMjoJce7FN9HC9wMBwGA1UdEQQVMBOCEWZvby5zbWFsbHN0ZXAuY29t
MEwGDCsGAQQBgqRkxihAAQQ8MDoCAQEECHN0ZXAtY2xpBCs0VUVMSng4ZTBhUzlt
MENIM2ZaMEVCN0Q1YVVQSUNiNzU5ekFMSEZlanZjMAoGCCqGSM49BAMCA0gAMEUC
IDxtNo1BX/4Sbf/+k1n+v//kh8ETr3clPvhjcyfvBIGTAiEAiT0kvbkPdCCnmHIw
lhpgBwT5YReZzBwIYXyKyJXc07M=
-----END CERTIFICATE-----

@ -0,0 +1,15 @@
-----BEGIN CERTIFICATE-----
MIICTDCCAfGgAwIBAgIQH1JRmbStwdCkiuqf7SM8dzAKBggqhkjOPQQDAjAnMSUw
IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyMjIz
MDI0OVoXDTE5MDMyMzIzMDI0OVowLjEsMCoGA1UEAxMjcHJvdmlzaW9uZXItbm90
LWZvdW5kLnNtYWxsc3RlcC5jb20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARw
DOZEqgkXXY0PqnEvl5ADX4xXMDNgX4lraK8SP48Ljo3vUn5FqARjKaBgPLfowFkQ
gnjsAbBPwzt4SUWZW0ybo4H3MIH0MA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAU
BggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0OBBYEFDLOyjWD26FV5lfIwPqegYIt
PdmSMB8GA1UdIwQYMBaAFKEe9IdMyaHdURMjoJce7FN9HC9wMC4GA1UdEQQnMCWC
I3Byb3Zpc2lvbmVyLW5vdC1mb3VuZC5zbWFsbHN0ZXAuY29tMFMGDCsGAQQBgqRk
xihAAQRDMEECAQEED2dpZkBleGFtcGxlLmNvbQQrRVdDQThsdFJCdEwxN2VFQS1I
dW4zQWtCN0sxTERhUXItNkdvdXc3RXBoVTAKBggqhkjOPQQDAgNJADBGAiEAkaHR
dE706JI8eLio/AqPbH8A/qK1INlbKbrkZ03K5wECIQCqTGY4TYopJqLYt3HkQeTy
cJfHpuPfIzvpT8X0h3zlwQ==
-----END CERTIFICATE-----

@ -0,0 +1,14 @@
-----BEGIN CERTIFICATE-----
MIICJjCCAcygAwIBAgIQWhtLLuWC1foM7eq1jefkGDAKBggqhkjOPQQDAjAnMSUw
IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyNzIz
Mzk0M1oXDTE5MDMyODIzMzk0M1owHDEaMBgGA1UEAxMRYmF6LnNtYWxsc3RlcC5j
b20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATxC77uJiCHgxIoctoHZbEauQwV
1FStMSKnEQwNkm88GD0HVUcz3g9OEHJbdMuY7VJjefD2NfdMil2N1jOw8VzMo4Hk
MIHhMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH
AwIwHQYDVR0OBBYEFCEoFgFtPV3v3YsJt7uYoz7GgChEMB8GA1UdIwQYMBaAFKEe
9IdMyaHdURMjoJce7FN9HC9wMBwGA1UdEQQVMBOCEWJhei5zbWFsbHN0ZXAuY29t
MFIGDCsGAQQBgqRkxihAAQRCMEACAQEEDnJlbmV3X2Rpc2FibGVkBCtJTWk5NFdC
Tkk2Z1A1Y05IWGxaWU5VenZNakdkSHlCUm1Gb28tbENFYXFrMAoGCCqGSM49BAMC
A0gAMEUCIQD1uGcIQYdEEtVtOFWZGhDk+QJTznH5C182k74Kj/Ns3QIgeNtqYeto
Ur1bgN1pwEwjTyr4aNz+pUWHZhyodduVaCE=
-----END CERTIFICATE-----

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIJmnxm3N/ahRA2PWeZhRGJUKPU1lI44WcE4P1bynIim6oAoGCCqGSM49
AwEHoUQDQgAEG6bXw6JxWnlD4k6+dsJfa93XM4K3qJcA0ZrIMhYJ69RLih+elkXI
B9YI3y9KcmxZTtz3YAbpEeLvrL3qt+NzgA==
-----END EC PRIVATE KEY-----

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILWLnE+pkh9QQ0CcM89sCBAWMEK7EtoJOmHvvFpugj2joAoGCCqGSM49
AwEHoUQDQgAEcAzmRKoJF12ND6pxL5eQA1+MVzAzYF+Ja2ivEj+PC46N71J+RagE
YymgYDy36MBZEIJ47AGwT8M7eElFmVtMmw==
-----END EC PRIVATE KEY-----

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKmDvbNqeIZA9zssZxixJzAQBEUEBSyVnjCKvTWGMAd2oAoGCCqGSM49
AwEHoUQDQgAE8Qu+7iYgh4MSKHLaB2WxGrkMFdRUrTEipxEMDZJvPBg9B1VHM94P
ThByW3TLmO1SY3nw9jX3TIpdjdYzsPFczA==
-----END EC PRIVATE KEY-----

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/asn1" "encoding/asn1"
"encoding/base64"
"encoding/pem" "encoding/pem"
"net/http" "net/http"
"strings" "strings"
@ -11,6 +12,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
@ -111,6 +113,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
http.StatusInternalServerError, errContext} http.StatusInternalServerError, errContext}
} }
if err = a.db.StoreCertificate(serverCert); err != nil {
if err != db.ErrNotImplemented {
return nil, nil, &apiError{errors.Wrap(err, "sign: error storing certificate in db"),
http.StatusInternalServerError, errContext}
}
}
return serverCert, caCert, nil return serverCert, caCert, nil
} }
@ -194,6 +203,80 @@ func (a *Authority) Renew(oldCert *x509.Certificate) (*x509.Certificate, *x509.C
return serverCert, caCert, nil return serverCert, caCert, nil
} }
// RevokeOptions are the options for the Revoke API.
type RevokeOptions struct {
Serial string
Reason string
ReasonCode int
PassiveOnly bool
MTLS bool
Crt *x509.Certificate
OTT string
errCtxt map[string]interface{}
}
// Revoke revokes a certificate.
//
// NOTE: Only supports passive revocation - prevent existing certificates from
// being renewed.
//
// TODO: Add OCSP and CRL support.
func (a *Authority) Revoke(opts *RevokeOptions) error {
errContext := context{
"serialNumber": opts.Serial,
"reasonCode": opts.ReasonCode,
"reason": opts.Reason,
"passiveOnly": opts.PassiveOnly,
"mTLS": opts.MTLS,
}
if opts.MTLS {
errContext["certificate"] = base64.StdEncoding.EncodeToString(opts.Crt.Raw)
} else {
errContext["ott"] = opts.OTT
}
rci := &db.RevokedCertificateInfo{
Serial: opts.Serial,
ReasonCode: opts.ReasonCode,
Reason: opts.Reason,
MTLS: opts.MTLS,
RevokedAt: time.Now().UTC(),
}
// Authorize mTLS or token request and get back a provisioner interface.
p, err := a.authorizeRevoke(opts)
if err != nil {
return &apiError{errors.Wrap(err, "revoke"),
http.StatusUnauthorized, errContext}
}
// If not mTLS then get the TokenID of the token.
if !opts.MTLS {
rci.TokenID, err = p.GetTokenID(opts.OTT)
if err != nil {
return &apiError{errors.Wrap(err, "revoke: could not get ID for token"),
http.StatusInternalServerError, errContext}
}
errContext["tokenID"] = rci.TokenID
}
rci.ProvisionerID = p.GetID()
errContext["provisionerID"] = rci.ProvisionerID
err = a.db.Revoke(rci)
switch err {
case nil:
return nil
case db.ErrNotImplemented:
return &apiError{errors.New("revoke: no persistence layer configured"),
http.StatusNotImplemented, errContext}
case db.ErrAlreadyExists:
return &apiError{errors.Errorf("revoke: certificate with serial number %s has already been revoked", rci.Serial),
http.StatusBadRequest, errContext}
default:
return &apiError{err, http.StatusInternalServerError, errContext}
}
}
// GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server. // GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server.
func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) { func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
profile, err := x509util.NewLeafProfile("Step Online CA", profile, err := x509util.NewLeafProfile("Step Online CA",

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
@ -15,10 +16,13 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"gopkg.in/square/go-jose.v2/jwt"
) )
var ( var (
@ -199,8 +203,36 @@ func TestSign(t *testing.T) {
}, },
} }
}, },
"fail store cert in db": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
_a := testAuthority(t)
_a.db = &MockAuthDB{
storeCertificate: func(crt *x509.Certificate) error {
return &apiError{errors.New("force"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts}}
},
}
return &signTest{
auth: _a,
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: error storing certificate in db: force"),
http.StatusInternalServerError,
context{"csr": csr, "signOptions": signOpts},
},
}
},
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv) csr := getCSR(t, priv)
_a := testAuthority(t)
_a.db = &MockAuthDB{
storeCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
}
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
@ -350,7 +382,7 @@ func TestRenew(t *testing.T) {
} }
return &renewTest{ return &renewTest{
crt: crtNoRenew, crt: crtNoRenew,
err: &apiError{errors.New("renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), err: &apiError{errors.New("renew: renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, ctx}, http.StatusUnauthorized, ctx},
}, nil }, nil
}, },
@ -528,3 +560,230 @@ func TestGetTLSOptions(t *testing.T) {
}) })
} }
} }
func TestRevoke(t *testing.T) {
reasonCode := 2
reason := "bob was let go"
validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
now := time.Now().UTC()
getCtx := func() map[string]interface{} {
return context{
"serialNumber": "sn",
"reasonCode": reasonCode,
"reason": reason,
"mTLS": false,
"passiveOnly": false,
}
}
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
type test struct {
a *Authority
opts *RevokeOptions
err *apiError
}
tests := map[string]func() test{
"error/token/authorizeRevoke error": func() test {
a := testAuthority(t)
a.db = new(db.NoopDB)
ctx := getCtx()
ctx["ott"] = "foo"
return test{
a: a,
opts: &RevokeOptions{
OTT: "foo",
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
},
err: &apiError{errors.New("revoke: authorizeRevoke: authorizeToken: error parsing token"),
http.StatusUnauthorized, ctx},
}
},
"error/nil-db": func() test {
a := testAuthority(t)
a.db = new(db.NoopDB)
cl := jwt.Claims{
Subject: "sn",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{
a: a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: &apiError{errors.New("revoke: no persistence layer configured"),
http.StatusNotImplemented, ctx},
}
},
"error/db-revoke": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{err: errors.New("force")}
cl := jwt.Claims{
Subject: "sn",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{
a: a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: &apiError{errors.New("force"),
http.StatusInternalServerError, ctx},
}
},
"error/already-revoked": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{err: db.ErrAlreadyExists}
cl := jwt.Claims{
Subject: "sn",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{
a: a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: &apiError{errors.New("revoke: certificate with serial number sn has already been revoked"),
http.StatusBadRequest, ctx},
}
},
"ok/token": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{}
cl := jwt.Claims{
Subject: "sn",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return test{
a: a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
}
},
"error/mTLS/authorizeRevoke": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{}
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
ctx := getCtx()
ctx["certificate"] = base64.StdEncoding.EncodeToString(crt.Raw)
ctx["mTLS"] = true
return test{
a: a,
opts: &RevokeOptions{
Crt: crt,
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
MTLS: true,
},
err: &apiError{errors.New("revoke: authorizeRevoke: serial number in certificate different than body"),
http.StatusUnauthorized, ctx},
}
},
"ok/mTLS": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{}
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
return test{
a: a,
opts: &RevokeOptions{
Crt: crt,
Serial: "102012593071130646873265215610956555026",
ReasonCode: reasonCode,
Reason: reason,
MTLS: true,
},
}
},
}
for name, f := range tests {
tc := f()
t.Run(name, func(t *testing.T) {
if err := tc.a.Revoke(tc.opts); err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

@ -122,6 +122,9 @@ func (ca *CA) Run() error {
// Stop stops the CA calling to the server Shutdown method. // Stop stops the CA calling to the server Shutdown method.
func (ca *CA) Stop() error { func (ca *CA) Stop() error {
ca.renewer.Stop() ca.renewer.Stop()
if err := ca.auth.Shutdown(); err != nil {
return err
}
return ca.srv.Shutdown() return ca.srv.Shutdown()
} }

@ -346,6 +346,36 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
return &sign, nil return &sign, nil
} }
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
// struct.
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
}
var client *http.Client
if tr != nil {
client = &http.Client{Transport: tr}
} else {
client = c.client
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"})
resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var revoke api.RevokeResponse
if err := readJSON(resp.Body, &revoke); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &revoke, nil
}
// Provisioners performs the provisioners request to the CA and returns the // Provisioners performs the provisioners request to the CA and returns the
// api.ProvisionersResponse struct with a map of provisioners. // api.ProvisionersResponse struct with a map of provisioners.
// //

@ -329,6 +329,81 @@ func TestClient_Sign(t *testing.T) {
} }
} }
func TestClient_Revoke(t *testing.T) {
ok := &api.RevokeResponse{Status: "ok"}
request := &api.RevokeRequest{
Serial: "sn",
OTT: "the-ott",
ReasonCode: 4,
}
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
request *api.RevokeRequest
response interface{}
responseCode int
wantErr bool
}{
{"ok", request, ok, 200, false},
{"unauthorized", request, unauthorized, 401, true},
{"nil request", nil, badRequest, 403, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
if !reflect.DeepEqual(body, &api.RevokeRequest{}) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
}
} else {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
}
}
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Revoke(tt.request, nil)
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Revoke() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Renew(t *testing.T) { func TestClient_Renew(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},

@ -0,0 +1,138 @@
package db
import (
"crypto/x509"
"encoding/json"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/nosql"
)
var (
revokedCertsTable = []byte("revoked_x509_certs")
certsTable = []byte("x509_certs")
)
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
// been previously set.
var ErrAlreadyExists = errors.New("already exists")
// Config represents the JSON attributes used for configuring a step-ca DB.
type Config struct {
Type string `json:"type"`
Path string `json:"path"`
}
// AuthDB is an interface over an Authority DB client that implements a nosql.DB interface.
type AuthDB interface {
IsRevoked(sn string) (bool, error)
Revoke(rci *RevokedCertificateInfo) error
StoreCertificate(crt *x509.Certificate) error
Shutdown() error
}
// DB is a wrapper over the nosql.DB interface.
type DB struct {
nosql.DB
}
// New returns a new database client that implements the AuthDB interface.
func New(c *Config) (AuthDB, error) {
if c == nil {
return new(NoopDB), nil
}
var db nosql.DB
switch strings.ToLower(c.Type) {
case "bbolt":
db = &nosql.BoltDB{}
if err := db.Open(c.Path); err != nil {
return nil, err
}
default:
return nil, errors.Errorf("unsupported db.type '%s'", c.Type)
}
tables := [][]byte{revokedCertsTable, certsTable}
for _, b := range tables {
if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s",
string(b))
}
}
return &DB{db}, nil
}
// RevokedCertificateInfo contains information regarding the certificate
// revocation action.
type RevokedCertificateInfo struct {
Serial string
ProvisionerID string
ReasonCode int
Reason string
RevokedAt time.Time
TokenID string
MTLS bool
}
// IsRevoked returns whether or not a certificate with the given identifier
// has been revoked.
// In the case of an X509 Certificate the `id` should be the Serial Number of
// the Certificate.
func (db *DB) IsRevoked(sn string) (bool, error) {
// If the DB is nil then act as pass through.
if db == nil {
return false, nil
}
// If the error is `Not Found` then the certificate has not been revoked.
// Any other error should be propagated to the caller.
if _, err := db.Get(revokedCertsTable, []byte(sn)); err != nil {
if nosql.IsErrNotFound(err) {
return false, nil
}
return false, errors.Wrap(err, "error checking revocation bucket")
}
// This certificate has been revoked.
return true, nil
}
// Revoke adds a certificate to the revocation table.
func (db *DB) Revoke(rci *RevokedCertificateInfo) error {
isRvkd, err := db.IsRevoked(rci.Serial)
if err != nil {
return err
}
if isRvkd {
return ErrAlreadyExists
}
rcib, err := json.Marshal(rci)
if err != nil {
return errors.Wrap(err, "error marshaling revoked certificate info")
}
if err = db.Set(revokedCertsTable, []byte(rci.Serial), rcib); err != nil {
return errors.Wrap(err, "database Set error")
}
return nil
}
// StoreCertificate stores a certificate PEM.
func (db *DB) StoreCertificate(crt *x509.Certificate) error {
if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil {
return errors.Wrap(err, "database Set error")
}
return nil
}
// Shutdown sends a shutdown message to the database.
func (db *DB) Shutdown() error {
if err := db.Close(); err != nil {
return errors.Wrap(err, "database shutdown error")
}
return nil
}

@ -0,0 +1,190 @@
package db
import (
"errors"
"testing"
"github.com/smallstep/assert"
"github.com/smallstep/nosql"
)
type MockNoSQLDB struct {
err error
ret1, ret2 interface{}
get func(bucket, key []byte) ([]byte, error)
set func(bucket, key, value []byte) error
open func(path string) error
close func() error
createTable func(bucket []byte) error
deleteTable func(bucket []byte) error
del func(bucket, key []byte) error
list func(bucket []byte) ([]*nosql.Entry, error)
update func(tx *nosql.Tx) error
}
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
if m.get != nil {
return m.get(bucket, key)
}
if m.ret1 == nil {
return nil, m.err
}
return m.ret1.([]byte), m.err
}
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
if m.set != nil {
return m.set(bucket, key, value)
}
return m.err
}
func (m *MockNoSQLDB) Open(path string) error {
if m.open != nil {
return m.open(path)
}
return m.err
}
func (m *MockNoSQLDB) Close() error {
if m.close != nil {
return m.close()
}
return m.err
}
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
if m.createTable != nil {
return m.createTable(bucket)
}
return m.err
}
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
if m.deleteTable != nil {
return m.deleteTable(bucket)
}
return m.err
}
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
if m.del != nil {
return m.del(bucket, key)
}
return m.err
}
func (m *MockNoSQLDB) List(bucket []byte) ([]*nosql.Entry, error) {
if m.list != nil {
return m.list(bucket)
}
return m.ret1.([]*nosql.Entry), m.err
}
func (m *MockNoSQLDB) Update(tx *nosql.Tx) error {
if m.update != nil {
return m.update(tx)
}
return m.err
}
func TestIsRevoked(t *testing.T) {
tests := map[string]struct {
key string
db *DB
isRevoked bool
err error
}{
"false/nil db": {
key: "sn",
},
"false/ErrNotFound": {
key: "sn",
db: &DB{&MockNoSQLDB{err: nosql.ErrNotFound, ret1: nil}},
},
"error/checking bucket": {
key: "sn",
db: &DB{&MockNoSQLDB{err: errors.New("force"), ret1: nil}},
err: errors.New("error checking revocation bucket: force"),
},
"true": {
key: "sn",
db: &DB{&MockNoSQLDB{ret1: []byte("value")}},
isRevoked: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
isRevoked, err := tc.db.IsRevoked(tc.key)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
assert.Fatal(t, isRevoked == tc.isRevoked)
}
})
}
}
func TestRevoke(t *testing.T) {
tests := map[string]struct {
rci *RevokedCertificateInfo
db *DB
err error
}{
"error/force isRevoked": {
rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) {
return nil, errors.New("force IsRevoked")
},
}},
err: errors.New("error checking revocation bucket: force IsRevoked"),
},
"error/was already revoked": {
rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) {
return nil, nil
},
}},
err: ErrAlreadyExists,
},
"error/database set": {
rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) {
return nil, nosql.ErrNotFound
},
set: func(bucket []byte, key []byte, value []byte) error {
return errors.New("force")
},
}},
err: errors.New("database Set error: force"),
},
"ok": {
rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) {
return nil, nosql.ErrNotFound
},
set: func(bucket []byte, key []byte, value []byte) error {
return nil
},
}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
if err := tc.db.Revoke(tc.rci); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

@ -0,0 +1,38 @@
package db
import (
"crypto/x509"
"github.com/pkg/errors"
)
// ErrNotImplemented is an error returned when an operation is Not Implemented.
var ErrNotImplemented = errors.Errorf("not implemented")
// NoopDB implements the DB interface with Noops
type NoopDB int
// Init noop
func (n *NoopDB) Init(c *Config) (AuthDB, error) {
return n, nil
}
// IsRevoked noop
func (n *NoopDB) IsRevoked(sn string) (bool, error) {
return false, nil
}
// Revoke returns a "NotImplemented" error.
func (n *NoopDB) Revoke(rci *RevokedCertificateInfo) error {
return ErrNotImplemented
}
// StoreCertificate returns a "NotImplemented" error.
func (n *NoopDB) StoreCertificate(crt *x509.Certificate) error {
return ErrNotImplemented
}
// Shutdown returns nil
func (n *NoopDB) Shutdown() error {
return nil
}

@ -0,0 +1,21 @@
package db
import (
"testing"
"github.com/smallstep/assert"
)
func Test_noop(t *testing.T) {
db := new(NoopDB)
_db, err := db.Init(&Config{})
assert.FatalError(t, err)
assert.Equals(t, db, _db)
isRevoked, err := db.IsRevoked("foo")
assert.False(t, isRevoked)
assert.Nil(t, err)
assert.Equals(t, db.Revoke(&RevokedCertificateInfo{}), ErrNotImplemented)
}
Loading…
Cancel
Save