diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 00000000..64b4d103 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,56 @@ +name: Bug Report +description: File a bug report +title: "[Bug]: " +labels: ["bug", "needs triage"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + - type: textarea + id: steps + attributes: + label: Steps to Reproduce + description: Tell us how to reproduce this issue. + placeholder: These are the steps! + validations: + required: true + - type: textarea + id: your-env + attributes: + label: Your Environment + value: |- + * OS - + * `step-ca` Version - + validations: + required: true + - type: textarea + id: expected-behavior + attributes: + label: Expected Behavior + description: What did you expect to happen? + validations: + required: true + - type: textarea + id: actual-behavior + attributes: + label: Actual Behavior + description: What happens instead? + validations: + required: true + - type: textarea + id: context + attributes: + label: Additional Context + description: Add any other context about the problem here. + validations: + required: false + - type: textarea + id: contributing + attributes: + label: Contributing + value: | + Vote on this issue by adding a 👍 reaction. + To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already). + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 066c4951..00000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,27 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: '' -labels: bug, needs triage -assignees: '' - ---- - -### Subject of the issue -Describe your issue here. - -### Your environment -* OS - -* Version - - -### Steps to reproduce -Tell us how to reproduce this issue. Please provide a working demo, you can use [this template](https://plnkr.co/edit/XorWgI?p=preview) as a base. - -### Expected behaviour -Tell us what should happen - -### Actual behaviour -Tell us what happens instead - -### Additional context -Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/documentation-request.md b/.github/ISSUE_TEMPLATE/documentation-request.md index 86d15328..cf0250ae 100644 --- a/.github/ISSUE_TEMPLATE/documentation-request.md +++ b/.github/ISSUE_TEMPLATE/documentation-request.md @@ -1,12 +1,20 @@ --- name: Documentation Request about: Request documentation for a feature -title: '' -labels: documentation, needs triage +title: '[Docs]:' +labels: docs, needs triage assignees: '' --- +## Hello! + + +- Vote on this issue by adding a 👍 reaction +- If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.) + +## Affected area/feature + +- Vote on this issue by adding a 👍 reaction +- If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.) -### Why this is needed +## Issue details + + + +## Why is this needed? + + diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 266e9124..5d38f102 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -1,4 +1,20 @@ -### Description -Please describe your pull request. + +#### Name of feature: + +#### Pain or issue this feature alleviates: + +#### Why is this important to the project (if not answered above): + +#### Is there documentation on how to use this feature? If so, where? + +#### In what environments or workflows is this feature supported? + +#### In what environments or workflows is this feature explicitly NOT supported (if any)? + +#### Supporting links/other PRs/issues: 💔Thank you! diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2ab7084d..807cfdd6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -33,7 +33,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: 'v1.45.0' + version: 'v1.45.2' # Optional: working directory, useful for monorepos # working-directory: somedir @@ -139,7 +139,7 @@ jobs: name: Run GoReleaser uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0 with: - version: latest + version: 'v1.7.0' args: release --rm-dist env: GITHUB_TOKEN: ${{ secrets.PAT }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 64cb64cd..046589af 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: 'v1.45.0' + version: 'v1.45.2' # Optional: working directory, useful for monorepos # working-directory: somedir @@ -59,8 +59,9 @@ jobs: - name: Codecov if: matrix.go == '1.18' - uses: codecov/codecov-action@v1.2.1 + uses: codecov/codecov-action@v2 with: - file: ./coverage.out # optional + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.out # optional name: codecov-umbrella # optional fail_ci_if_error: true # optional (default = false) diff --git a/.goreleaser.yml b/.goreleaser.yml index 441d5785..7d57e657 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -230,42 +230,3 @@ scoop: # Your app's license # Default is empty. license: "Apache-2.0" - - #dockers: - # - dockerfile: docker/Dockerfile - # goos: linux - # goarch: amd64 - # use_buildx: true - # image_templates: - # - "smallstep/step-cli:latest" - # - "smallstep/step-cli:{{ .Tag }}" - # build_flag_templates: - # - "--platform=linux/amd64" - # - dockerfile: docker/Dockerfile - # goos: linux - # goarch: 386 - # use_buildx: true - # image_templates: - # - "smallstep/step-cli:latest" - # - "smallstep/step-cli:{{ .Tag }}" - # build_flag_templates: - # - "--platform=linux/386" - # - dockerfile: docker/Dockerfile - # goos: linux - # goarch: arm - # goarm: 7 - # use_buildx: true - # image_templates: - # - "smallstep/step-cli:latest" - # - "smallstep/step-cli:{{ .Tag }}" - # build_flag_templates: - # - "--platform=linux/arm/v7" - # - dockerfile: docker/Dockerfile - # goos: linux - # goarch: arm64 - # use_buildx: true - # image_templates: - # - "smallstep/step-cli:latest" - # - "smallstep/step-cli:{{ .Tag }}" - # build_flag_templates: - # - "--platform=linux/arm64/v8" diff --git a/CHANGELOG.md b/CHANGELOG.md index 49e4b15e..f6134aa5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,16 +4,70 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). -## [Unreleased - 0.18.3] - DATE +### TEMPLATE -- do not alter or remove +--- +## [x.y.z] - aaaa-bb-cc ### Added -- Added support for renew after expiry using the claim `allowRenewAfterExpiry`. +### Changed +### Deprecated +### Removed +### Fixed +### Security +--- + +## [Unreleased] +### Changed +- Certificates signed by an issuer using an RSA key will be signed using the same algorithm as the issuer certificate was signed with. The signature will no longer default to PKCS #1. For example, if the issuer certificate was signed using RSA-PSS with SHA-256, a new certificate will also be signed using RSA-PSS with SHA-256. + +## [0.20.0] - 2022-05-26 +### Added +- Added Kubernetes auth method for Vault RAs. +- Added support for reporting provisioners to linkedca. +- Added support for certificate policies on authority level. +- Added a Dockerfile with a step-ca build with HSM support. +- A few new WithXX methods for instantiating authorities +### Changed +- Context usage in HTTP APIs. +- Changed authentication for Vault RAs. +- Error message returned to client when authenticating with expired certificate. +- Strip padding from ACME CSRs. +### Deprecated +- HTTP API handler types. +### Fixed +- Fixed SSH revocation. +- CA client dial context for js/wasm target. +- Incomplete `extraNames` support in templates. +- SCEP GET request support. +- Large SCEP request handling. + +## [0.19.0] - 2022-04-19 +### Added +- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`. - Added support for `extraNames` in X.509 templates. +- Added `armv5` builds. +- Added RA support using a Vault instance as the CA. +- Added `WithX509SignerFunc` authority option. +- Added a new `/roots.pem` endpoint to download the CA roots in PEM format. +- Added support for Azure `Managed Identity` tokens. +- Added support for automatic configuration of linked RAs. +- Added support for the `--context` flag. It's now possible to start the + CA with `step-ca --context=abc` to use the configuration from context `abc`. + When a context has been configured and no configuration file is provided + on startup, the configuration for the current context is used. +- Added startup info logging and option to skip it (`--quiet`). +- Added support for renaming the CA (Common Name). ### Changed -- Made SCEP CA URL paths dynamic -- Support two latest versions of Go (1.17, 1.18) +- Made SCEP CA URL paths dynamic. +- Support two latest versions of Go (1.17, 1.18). +- Upgrade go.step.sm/crypto to v0.16.1. +- Upgrade go.step.sm/linkedca to v0.15.0. ### Deprecated +- Go 1.16 support. ### Removed ### Fixed +- Fixed admin credentials on RAs. +- Fixed ACME HTTP-01 challenges for IPv6 identifiers. +- Various improvements under the hood. ### Security ## [0.18.2] - 2022-03-01 @@ -49,7 +103,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Support for multiple certificate authority contexts. - Support for generating extractable keys and certificates on a pkcs#11 module. ### Changed -- Support two latest versions of golang (1.16, 1.17) +- Support two latest versions of Go (1.16, 1.17) ### Deprecated - go 1.15 support diff --git a/Makefile b/Makefile index 09e342df..906569f1 100644 --- a/Makefile +++ b/Makefile @@ -151,7 +151,7 @@ integration: bin/$(BINNAME) ######################################### fmt: - $Q gofmt -l -w $(SRC) + $Q gofmt -l -s -w $(SRC) lint: $Q golangci-lint run --timeout=30m diff --git a/README.md b/README.md index 5c29ccdf..1efeb4a9 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te - [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation - Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries - Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca) -- [Badger, BoltDB, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases) +- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases) ### ⚙️ Many ways to automate @@ -68,6 +68,7 @@ You can issue certificates in exchange for: - [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure - [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc. - A trusted X.509 certificate (X5C provisioner) +- A host certificate from your Nebula network - A SCEP challenge (SCEP provisioner) - An SSH host certificates needing renewal (the SSHPOP provisioner) - Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners) diff --git a/acme/account.go b/acme/account.go index 027d7be1..2dd412db 100644 --- a/acme/account.go +++ b/acme/account.go @@ -7,6 +7,8 @@ import ( "time" "go.step.sm/crypto/jose" + + "github.com/smallstep/certificates/authority/policy" ) // Account is a subset of the internal account type containing only those @@ -43,15 +45,63 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) { return base64.RawURLEncoding.EncodeToString(kid), nil } +// PolicyNames contains ACME account level policy names +type PolicyNames struct { + DNSNames []string `json:"dns"` + IPRanges []string `json:"ips"` +} + +// X509Policy contains ACME account level X.509 policy +type X509Policy struct { + Allowed PolicyNames `json:"allow"` + Denied PolicyNames `json:"deny"` + AllowWildcardNames bool `json:"allowWildcardNames"` +} + +// Policy is an ACME Account level policy +type Policy struct { + X509 X509Policy `json:"x509"` +} + +func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions { + if p == nil { + return nil + } + return &policy.X509NameOptions{ + DNSDomains: p.X509.Allowed.DNSNames, + IPRanges: p.X509.Allowed.IPRanges, + } +} +func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions { + if p == nil { + return nil + } + return &policy.X509NameOptions{ + DNSDomains: p.X509.Denied.DNSNames, + IPRanges: p.X509.Denied.IPRanges, + } +} + +// AreWildcardNamesAllowed returns if wildcard names +// like *.example.com are allowed to be signed. +// Defaults to false. +func (p *Policy) AreWildcardNamesAllowed() bool { + if p == nil { + return false + } + return p.X509.AllowWildcardNames +} + // ExternalAccountKey is an ACME External Account Binding key. type ExternalAccountKey struct { ID string `json:"id"` ProvisionerID string `json:"provisionerID"` Reference string `json:"reference"` AccountID string `json:"-"` - KeyBytes []byte `json:"-"` + HmacKey []byte `json:"-"` CreatedAt time.Time `json:"createdAt"` BoundAt time.Time `json:"boundAt,omitempty"` + Policy *Policy `json:"policy,omitempty"` } // AlreadyBound returns whether this EAK is already bound to @@ -68,6 +118,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error { } eak.AccountID = account.ID eak.BoundAt = time.Now() - eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once + eak.HmacKey = []byte{} // clearing the key bytes; can only be used once return nil } diff --git a/acme/account_test.go b/acme/account_test.go index 33524d87..edd1f5b0 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -7,8 +7,9 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/assert" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" ) func TestKeyToID(t *testing.T) { @@ -95,7 +96,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) { ID: "eakID", ProvisionerID: "provID", Reference: "ref", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, }, acct: &Account{ ID: "accountID", @@ -108,7 +109,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) { ID: "eakID", ProvisionerID: "provID", Reference: "ref", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, AccountID: "someAccountID", BoundAt: boundAt, }, @@ -138,7 +139,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) { assert.Equals(t, ae.Subproblems, tt.err.Subproblems) } else { assert.Equals(t, eak.AccountID, acct.ID) - assert.Equals(t, eak.KeyBytes, []byte{}) + assert.Equals(t, eak.HmacKey, []byte{}) assert.NotNil(t, eak.BoundAt) } }) diff --git a/acme/api/account.go b/acme/api/account.go index ade51aef..710747ca 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error { } // NewAccount is the handler resource for creating new ACME accounts. -func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { +func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } - eak, err := h.validateExternalAccountBinding(ctx, &nar) + eak, err := validateExternalAccountBinding(ctx, &nar) if err != nil { render.Error(w, err) return @@ -125,18 +128,17 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Contact: nar.Contact, Status: acme.StatusValid, } - if err := h.db.CreateAccount(ctx, acc); err != nil { + if err := db.CreateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response - err := eak.BindTo(acc) - if err != nil { + if err := eak.BindTo(acc); err != nil { render.Error(w, err) return } - if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { + if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } @@ -147,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - h.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID)) render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. -func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { +func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -187,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc.Contact = uar.Contact } - if err := h.db.UpdateAccount(ctx, acc); err != nil { + if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } } - h.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) render.JSON(w, acc) } @@ -210,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { +func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -222,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } - orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) + + orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, err) return } - h.linker.LinkOrdersByAccountID(ctx, orders) + linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, orders) logOrdersByAccount(w, orders) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 4c3404ec..d81553d2 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -13,10 +13,12 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + + "go.step.sm/crypto/jose" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" ) var ( @@ -29,6 +31,22 @@ var ( } ) +type fakeProvisioner struct{} + +func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { + return nil +} + +func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { + return nil, nil +} + +func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil } +func (*fakeProvisioner) GetID() string { return "" } +func (*fakeProvisioner) GetName() string { return "" } +func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 } +func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil } + func newProv() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ @@ -41,6 +59,19 @@ func newProv() acme.Provisioner { return p } +func newProvWithOptions(options *provisioner.Options) acme.Provisioner { + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + Options: options, + } + if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { + fmt.Printf("%v", err) + } + return p +} + func newACMEProv(t *testing.T) *provisioner.ACME { p := newProv() a, ok := p.(*provisioner.ACME) @@ -50,6 +81,15 @@ func newACMEProv(t *testing.T) *provisioner.ACME { return a } +func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME { + p := newProvWithOptions(options) + a, ok := p.(*provisioner.ACME) + if !ok { + t.Fatal("not a valid ACME provisioner") + } + return a +} + func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) { signer, err := jose.NewSigner( jose.SigningKey{ @@ -296,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = acme.NewProvisionerContext(ctx, prov) + ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { @@ -315,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetOrdersByAccountID(w, req) + GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -363,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -371,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -379,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/unmarshal-payload-error": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to "+ @@ -393,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -405,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) { b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -418,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -432,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -454,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), @@ -471,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ db: &acme.MockDB{ @@ -501,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - scepProvisioner := &provisioner.SCEP{ - Type: "SCEP", - Name: "test@scep-provisioner.com", - } - if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - assert.FatalError(t, err) - } ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), @@ -551,14 +590,13 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) eak := &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), } return test{ @@ -599,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -635,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) { Status: acme.StatusValid, Contact: []string{"foo", "bar"}, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, acc: acc, statusCode: 200, @@ -664,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = false ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -719,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -735,7 +770,7 @@ func TestHandler_NewAccount(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -759,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.NewAccount(w, req) + NewAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -814,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -822,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/nil-account": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -830,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -839,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -848,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), @@ -862,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -894,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -914,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -929,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -946,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -959,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetOrUpdateAccount(w, req) + GetOrUpdateAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/eab.go b/acme/api/eab.go index 3660d066..4c4fff04 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -4,8 +4,9 @@ import ( "context" "encoding/json" - "github.com/smallstep/certificates/acme" "go.step.sm/crypto/jose" + + "github.com/smallstep/certificates/acme" ) // ExternalAccountBinding represents the ACME externalAccountBinding JWS @@ -16,7 +17,7 @@ type ExternalAccountBinding struct { } // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. -func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { +func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") @@ -47,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc return nil, acmeErr } - externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) + db := acme.MustDatabaseFromContext(ctx) + externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { if _, ok := err.(*acme.Error); ok { return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") @@ -55,11 +57,19 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc return nil, acme.WrapErrorISE(err, "error retrieving external account key") } + if externalAccountKey == nil { + return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key") + } + + if len(externalAccountKey.HmacKey) == 0 { + return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID) + } + if externalAccountKey.AlreadyBound() { return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt) } - payload, err := eabJWS.Verify(externalAccountKey.KeyBytes) + payload, err := eabJWS.Verify(externalAccountKey.HmacKey) if err != nil { return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature") } @@ -97,12 +107,12 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool { // validateEABJWS verifies the contents of the External Account Binding JWS. // The protected header of the JWS MUST meet the following criteria: -// o The "alg" field MUST indicate a MAC-based algorithm -// o The "kid" field MUST contain the key identifier provided by the CA -// o The "nonce" field MUST NOT be present -// o The "url" field MUST be set to the same value as the outer JWS +// +// - The "alg" field MUST indicate a MAC-based algorithm +// - The "kid" field MUST contain the key identifier provided by the CA +// - The "nonce" field MUST NOT be present +// - The "url" field MUST be set to the same value as the outer JWS func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { - if jws == nil { return "", acme.NewErrorISE("no JWS provided") } diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go index dce9f36d..d2e596f9 100644 --- a/acme/api/eab_test.go +++ b/acme/api/eab_test.go @@ -9,10 +9,11 @@ import ( "time" "github.com/pkg/errors" + + "go.step.sm/crypto/jose" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" ) func Test_keysAreEqual(t *testing.T) { @@ -98,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -143,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ @@ -154,7 +153,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: createdAt, }, nil }, @@ -168,7 +167,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: createdAt, }, err: nil, @@ -189,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - scepProvisioner := &provisioner.SCEP{ - Type: "SCEP", - Name: "test@scep-provisioner.com", - } - if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - assert.FatalError(t, err) - } + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{}) return test{ ctx: ctx, err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), @@ -218,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -264,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{}, @@ -310,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -358,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -408,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -426,6 +413,112 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { err: acme.NewErrorISE("error retrieving external account key"), } }, + "fail/db.GetExternalAccountKey-nil": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = acme.NewProvisionerContext(ctx, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return nil, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"), + } + }, + "fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = acme.NewProvisionerContext(ctx, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + createdAt := time.Now() + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + CreatedAt: createdAt, + AccountID: "some-account-id", + HmacKey: []byte{}, + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"), + } + }, "fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) @@ -458,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -506,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() boundAt := time.Now().Add(1 * time.Second) @@ -520,6 +611,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { Reference: "testeak", CreatedAt: createdAt, AccountID: "some-account-id", + HmacKey: []byte{1, 3, 3, 7}, BoundAt: boundAt, }, nil }, @@ -565,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -575,7 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 2, 3, 4}, + HmacKey: []byte{1, 2, 3, 4}, CreatedAt: time.Now(), }, nil }, @@ -623,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -633,7 +723,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -678,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -688,7 +777,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -734,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -744,7 +832,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -762,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - db: tc.db, - } - got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar) + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) + got, err := validateExternalAccountBinding(ctx, tc.nar) wantErr := tc.err != nil gotErr := err != nil if wantErr != gotErr { @@ -787,7 +873,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { } else { assert.NotNil(t, tc.eak) assert.Equals(t, got.ID, tc.eak.ID) - assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, got.HmacKey, tc.eak.HmacKey) assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, got.Reference, tc.eak.Reference) assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt) diff --git a/acme/api/handler.go b/acme/api/handler.go index 10eb22cb..2e3931b1 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -2,12 +2,10 @@ package api import ( "context" - "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" - "net" "net/http" "time" @@ -16,6 +14,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) @@ -39,111 +38,152 @@ type payloadInfo struct { isEmptyJSON bool } -// Handler is the ACME API request handler. -type Handler struct { - db acme.DB - backdate provisioner.Duration - ca acme.CertificateAuthority - linker Linker - validateChallengeOptions *acme.ValidateChallengeOptions - prerequisitesChecker func(ctx context.Context) (bool, error) -} - // HandlerOptions required to create a new ACME API request handler. type HandlerOptions struct { - Backdate provisioner.Duration - // DB storage backend that impements the acme.DB interface. + // DB storage backend that implements the acme.DB interface. + // + // Deprecated: use acme.NewContex(context.Context, acme.DB) DB acme.DB + + // CA is the certificate authority interface. + // + // Deprecated: use authority.NewContext(context.Context, *authority.Authority) + CA acme.CertificateAuthority + + // Backdate is the duration that the CA will subtract from the current time + // to set the NotBefore in the certificate. + Backdate provisioner.Duration + // DNS the host used to generate accurate ACME links. By default the authority // will use the Host from the request, so this value will only be used if // request.Host is empty. DNS string + // Prefix is a URL path prefix under which the ACME api is served. This // prefix is required to generate accurate ACME links. // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- // "acme" is the prefix from which the ACME api is accessed. Prefix string - CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are // met by the CA configuration. PrerequisitesChecker func(ctx context.Context) (bool, error) } +var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return authority.MustFromContext(ctx) +} + +// handler is the ACME API request handler. +type handler struct { + opts *HandlerOptions +} + +// Route traffic and implement the Router interface. For backward compatibility +// this route adds will add a new middleware that will set the ACME components +// on the context. +// +// Note: this method is deprecated in step-ca, other applications can still use +// this to support ACME, but the recommendation is to use use +// api.Route(api.Router) and acme.NewContext() instead. +func (h *handler) Route(r api.Router) { + client := acme.NewClient() + linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix) + route(r, func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker) + next(w, r.WithContext(ctx)) + } + }) +} + // NewHandler returns a new ACME API handler. -func NewHandler(ops HandlerOptions) api.RouterHandler { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - client := http.Client{ - Timeout: 30 * time.Second, - Transport: transport, - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - prerequisitesChecker := func(ctx context.Context) (bool, error) { - // by default all prerequisites are met - return true, nil - } - if ops.PrerequisitesChecker != nil { - prerequisitesChecker = ops.PrerequisitesChecker - } - return &Handler{ - ca: ops.CA, - db: ops.DB, - backdate: ops.Backdate, - linker: NewLinker(ops.DNS, ops.Prefix), - validateChallengeOptions: &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }, - prerequisitesChecker: prerequisitesChecker, +// +// Note: this method is deprecated in step-ca, other applications can still use +// this to support ACME, but the recommendation is to use use +// api.Route(api.Router) and acme.NewContext() instead. +func NewHandler(opts HandlerOptions) api.RouterHandler { + return &handler{ + opts: &opts, } } -// Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { - getPath := h.linker.GetUnescapedPathSuffix - // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) +// Route traffic and implement the Router interface. This method requires that +// all the acme components, authority, db, client, linker, and prerequisite +// checker to be present in the context. +func Route(r api.Router) { + route(r, nil) +} +func route(r api.Router, middleware func(next nextHTTP) nextHTTP) { + commonMiddleware := func(next nextHTTP) nextHTTP { + handler := func(w http.ResponseWriter, r *http.Request) { + // Linker middleware gets the provisioner and current url from the + // request and sets them in the context. + linker := acme.MustLinkerFromContext(r.Context()) + linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r) + } + if middleware != nil { + handler = middleware(handler) + } + return handler + } validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) + return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) + return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) + return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) + return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))) } - r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) - r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) - r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) - r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) - r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) + getPath := acme.GetUnescapedPathSuffix + + // Standard ACME API + r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + + r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"), + extractPayloadByJWK(NewAccount)) + r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(GetOrUpdateAccount)) + r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(NotImplemented)) + r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"), + extractPayloadByKid(NewOrder)) + r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(isPostAsGet(GetOrder))) + r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) + r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(FinalizeOrder)) + r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"), + extractPayloadByKid(isPostAsGet(GetAuthorization))) + r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), + extractPayloadByKid(GetChallenge)) + r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"), + extractPayloadByKid(isPostAsGet(GetCertificate))) + r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"), + extractPayloadByKidOrJWK(RevokeCert)) } // GetNonce just sets the right header since a Nonce is added to each response // by middleware by default. -func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { +func GetNonce(w http.ResponseWriter, r *http.Request) { if r.Method == "HEAD" { w.WriteHeader(http.StatusOK) } else { @@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) { // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. -func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { +func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { @@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { return } + linker := acme.MustLinkerFromContext(ctx) render.JSON(w, &Directory{ - NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), - NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), - NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), - RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), - KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), + NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), + NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), + NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), + RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType), + KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType), Meta: Meta{ ExternalAccountRequired: acmeProv.RequireEAB, }, @@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. -func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { +func NotImplemented(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. -func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { +func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) + az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return @@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } - if err = az.UpdateStatus(ctx, h.db); err != nil { + if err = az.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } - h.linker.LinkAuthorization(ctx, az) + linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. -func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { +func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // we'll just ignore the body. azID := chi.URLParam(r, "authzID") - ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) + ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return @@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { + if err = ch.Validate(ctx, db, jwk); err != nil { render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } - h.linker.LinkChallenge(ctx, ch, azID) + linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) - w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) + w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. -func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { +func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - certID := chi.URLParam(r, "certID") - cert, err := h.db.GetCertificate(ctx, certID) + certID := chi.URLParam(r, "certID") + cert, err := db.GetCertificate(ctx, certID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 67f7df30..822409df 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -19,11 +20,33 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + +func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return a + } +} + func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string @@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name - h.GetNonce(w, req) + GetNonce(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) { } func TestHandler_GetDirectory(t *testing.T) { - linker := NewLinker("ca.smallstep.com", "acme") + linker := acme.NewLinker("ca.smallstep.com", "acme") + _ = linker type test struct { ctx context.Context statusCode int @@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/no-provisioner": func(t *testing.T) test { - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - ctx: ctx, + ctx: context.Background(), statusCode: 500, - err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + err: acme.NewErrorISE("provisioner is not in context"), } }, "fail/different-provisioner": func(t *testing.T) test { - prov := &provisioner.SCEP{ - Type: "SCEP", - Name: "test@scep-provisioner.com", - } - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{}) return test{ ctx: ctx, statusCode: 500, @@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov.RequireEAB = true provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: linker} + ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetDirectory(w, req) + GetDirectory(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, @@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { @@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetAuthorization(w, req) + GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetCertificate(w, req) + GetCertificate(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) { type test struct { db acme.DB - vco *acme.ValidateChallengeOptions + vc acme.Client ctx context.Context statusCode int ch *acme.Challenge @@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/db.GetChallenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/no-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, jwkContextKey, nil) @@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) { return acme.NewErrorISE("force") }, }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx = context.WithValue(ctx, jwkContextKey, &_pub) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) { URL: u, Error: acme.NewError(acme.ErrorConnectionType, "force"), }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetChallenge(w, req) + GetChallenge(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 10f7841f..5dcb93e3 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -9,7 +9,6 @@ import ( "net/url" "strings" - "github.com/go-chi/chi" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" @@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) { } } -// baseURLFromRequest determines the base URL which should be used for -// constructing link URLs in e.g. the ACME directory result by taking the -// request Host into consideration. -// -// If the Request.Host is an empty string, we return an empty string, to -// indicate that the configured URL values should be used instead. If this -// function returns a non-empty result, then this should be used in -// constructing ACME link URLs. -func baseURLFromRequest(r *http.Request) *url.URL { - // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go - // for an implementation that allows HTTP requests using the x-forwarded-proto - // header. - - if r.Host == "" { - return nil - } - return &url.URL{Scheme: "https", Host: r.Host} -} - -// baseURLFromRequest is a middleware that extracts and caches the baseURL -// from the request. -// E.g. https://ca.smallstep.com/ -func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) - next(w, r.WithContext(ctx)) - } -} - // addNonce is a middleware that adds a nonce to the response header. -func (h *Handler) addNonce(next nextHTTP) nextHTTP { +func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.db.CreateNonce(r.Context()) + db := acme.MustDatabaseFromContext(r.Context()) + nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, err) return @@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // addDirLink is a middleware that adds a 'Link' response reader with the // directory index url. -func (h *Handler) addDirLink(next nextHTTP) nextHTTP { +func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) + ctx := r.Context() + linker := acme.MustLinkerFromContext(ctx) + + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) next(w, r) } } // verifyContentType is a middleware that verifies that content type is // application/jose+json. -func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { +func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { render.Error(w, err) return } - u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} + u := &url.URL{ + Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), + } + + var expected []string if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} @@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { } // parseJWS is a middleware that parses a request body into a JSONWebSignature struct. -func (h *Handler) parseJWS(next nextHTTP) nextHTTP { +func parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -142,17 +119,19 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // The JWS Unprotected Header [RFC7515] MUST NOT be used // The JWS Payload MUST NOT be detached // The JWS Protected Header MUST include the following fields: -// * “alg” (Algorithm) -// * This field MUST NOT contain “none” or a Message Authentication Code -// (MAC) algorithm (e.g. one in which the algorithm registry description -// mentions MAC/HMAC). -// * “nonce” (defined in Section 6.5) -// * “url” (defined in Section 6.4) -// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below -func (h *Handler) validateJWS(next nextHTTP) nextHTTP { +// - “alg” (Algorithm). +// This field MUST NOT contain “none” or a Message Authentication Code +// (MAC) algorithm (e.g. one in which the algorithm registry description +// mentions MAC/HMAC). +// - “nonce” (defined in Section 6.5) +// - “url” (defined in Section 6.4) +// - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below +func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustDatabaseFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { } // Check the validity/freshness of the Nonce. - if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { + if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { render.Error(w, err) return } @@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // extractJWK is a middleware that extracts the JWK from the JWS and saves it // in the context. Make sure to parse and validate the JWS before running this // middleware. -func (h *Handler) extractJWK(next nextHTTP) nextHTTP { +func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustDatabaseFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx = context.WithValue(ctx, jwkContextKey, jwk) // Get Account OR continue to generate a new one OR continue Revoke with certificate private key - acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) + acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case errors.Is(err, acme.ErrNotFound): // For NewAccount and Revoke requests ... @@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { } } -// lookupProvisioner loads the provisioner associated with the request. -// Responds 404 if the provisioner does not exist. -func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") - name, err := url.PathUnescape(nameEscaped) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) - return - } - p, err := h.ca.LoadProvisionerByName(name) - if err != nil { - render.Error(w, err) - return - } - acmeProv, ok := p.(*provisioner.ACME) - if !ok { - render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) - return - } - ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) - next(w, r.WithContext(ctx)) - } -} - // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. -func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { +func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ok, err := h.prerequisitesChecker(ctx) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) - return - } - if !ok { - render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) - return + // If the function is not set assume that all prerequisites are met. + checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx) + if ok { + ok, err := checkFunc(ctx) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + return + } } - next(w, r.WithContext(ctx)) + next(w, r) } } // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { +func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return } - kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") + kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, acme.NewError(acme.ErrorMalformedType, @@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.db.GetAccount(ctx, accID) + acc, err := db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) @@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // extractOrLookupJWK forwards handling to either extractJWK or // lookupJWK based on the presence of a JWK or a KID, respectively. -func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { +func extractOrLookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { // and it can be used to check if a JWK exists. This flow is used when the ACME client // signed the payload with a certificate private key. if canExtractJWKFrom(jws) { - h.extractJWK(next)(w, r) + extractJWK(next)(w, r) return } // default to looking up the JWK based on KeyID. This flow is used when the ACME client // signed the payload with an account private key. - h.lookupJWK(next)(w, r) + lookupJWK(next)(w, r) } } @@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { +func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { } // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). -func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { +func isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { @@ -462,16 +424,12 @@ type ContextKey string const ( // accContextKey account key accContextKey = ContextKey("acc") - // baseURLContextKey baseURL key - baseURLContextKey = ContextKey("baseURL") // jwsContextKey jws key jwsContextKey = ContextKey("jws") // jwkContextKey jwk key jwkContextKey = ContextKey("jwk") // payloadContextKey payload key payloadContextKey = ContextKey("payload") - // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") ) // accountFromContext searches the context for an ACME account. Returns the @@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) { return val, nil } -// baseURLFromContext returns the baseURL if one is stored in the context. -func baseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(baseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - // jwkFromContext searches the context for a JWK. Returns the JWK or an error. func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) @@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { // provisionerFromContext searches the context for a provisioner. Returns the // provisioner or an error. func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { - val := ctx.Value(provisionerContextKey) - if val == nil { + p, ok := acme.ProvisionerFromContext(ctx) + if !ok || p == nil { return nil, acme.NewErrorISE("provisioner expected in request context") } - pval, ok := val.(acme.Provisioner) - if !ok || pval == nil { - return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") - } - return pval, nil + return p, nil } // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { - prov, err := provisionerFromContext(ctx) + p, err := provisionerFromContext(ctx) if err != nil { return nil, err } - acmeProv, ok := prov.(*provisioner.ACME) - if !ok || acmeProv == nil { + ap, ok := p.(*provisioner.ACME) + if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } - return acmeProv, nil + + return ap, nil } // payloadFromContext searches the context for a payload. Returns the payload diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 8003fa16..193f5347 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) { w.Write(testBody) } -func Test_baseURLFromRequest(t *testing.T) { - tests := []struct { - name string - targetURL string - expectedResult *url.URL - requestPreparer func(*http.Request) - }{ - { - "HTTPS host pass-through failed.", - "https://my.dummy.host", - &url.URL{Scheme: "https", Host: "my.dummy.host"}, - nil, - }, - { - "Port pass-through failed", - "https://host.with.port:8080", - &url.URL{Scheme: "https", Host: "host.with.port:8080"}, - nil, - }, - { - "Explicit host from Request.Host was not used.", - "https://some.target.host:8080", - &url.URL{Scheme: "https", Host: "proxied.host"}, - func(r *http.Request) { - r.Host = "proxied.host" - }, - }, - { - "Missing Request.Host value did not result in empty string result.", - "https://some.host", - nil, - func(r *http.Request) { - r.Host = "" - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("GET", tc.targetURL, nil) - if tc.requestPreparer != nil { - tc.requestPreparer(request) - } - result := baseURLFromRequest(request) - if result == nil || tc.expectedResult == nil { - assert.Equals(t, result, tc.expectedResult) - } else if result.String() != tc.expectedResult.String() { - t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String()) - } - }) - } -} - -func TestHandler_baseURLFromRequest(t *testing.T) { - h := &Handler{} - req := httptest.NewRequest("GET", "/foo", nil) - req.Host = "test.ca.smallstep.com:8080" - w := httptest.NewRecorder() - - next := func(w http.ResponseWriter, r *http.Request) { - bu := baseURLFromContext(r.Context()) - if assert.NotNil(t, bu) { - assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") - assert.Equals(t, bu.Scheme, "https") +func newBaseContext(ctx context.Context, args ...interface{}) context.Context { + for _, a := range args { + switch v := a.(type) { + case acme.DB: + ctx = acme.NewDatabaseContext(ctx, v) + case acme.Linker: + ctx = acme.NewLinkerContext(ctx, v) + case acme.PrerequisitesChecker: + ctx = acme.NewPrerequisitesCheckerContext(ctx, v) } } - - h.baseURLFromRequest(next)(w, req) - - req = httptest.NewRequest("GET", "/foo", nil) - req.Host = "" - - next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, baseURLFromContext(r.Context()), nil) - } - - h.baseURLFromRequest(next)(w, req) + return ctx } func TestHandler_addNonce(t *testing.T) { @@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", u, nil) + ctx := newBaseContext(context.Background(), tc.db) + req := httptest.NewRequest("GET", u, nil).WithContext(ctx) w := httptest.NewRecorder() - h.addNonce(testNext)(w, req) + addNonce(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { link string - linker Linker statusCode int ctx context.Context err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) return test{ - linker: NewLinker("dns", "acme"), ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, @@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.addDirLink(testNext)(w, req) + addDirLink(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { - h Handler ctx context.Context contentType string err *acme.Error @@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/provisioner-not-set": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, url: u, ctx: context.Background(), contentType: "foo", @@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, url: u, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), @@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), @@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) { }, "ok": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkix-cert", statusCode: 200, } }, "ok/certificate/jose+json": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) { req = req.WithContext(tc.ctx) req.Header.Add("Content-Type", tc.contentType) w := httptest.NewRecorder() - tc.h.verifyContentType(testNext)(w, req) + verifyContentType(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.isPostAsGet(testNext)(w, req) + isPostAsGet(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, tc.body) w := httptest.NewRecorder() - h.parseJWS(tc.next)(w, req) + parseJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.verifyAndExtractJWSPayload(tc.next)(w, req) + verifyAndExtractJWSPayload(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - linker Linker + linker acme.Linker db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) @@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), @@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _parsed) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) @@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) { } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.lookupJWK(tc.next)(w, req) + lookupJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), @@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.extractJWK(tc.next)(w, req) + extractJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/nil-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/no-signature": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), @@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), @@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), @@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), @@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), @@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.validateJWS(tc.next)(w, req) + validateJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { u := "https://ca.smallstep.com/acme/account" type test struct { db acme.DB - linker Linker + linker acme.Linker statusCode int ctx context.Context err *acme.Error @@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), db: &acme.MockDB{ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { assert.Equals(t, kid, pub.KeyID) @@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ - linker: NewLinker("test.ca.smallstep.com", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, acc.ID) @@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.extractOrLookupJWK(tc.next)(w, req) + extractOrLookupJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) { u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName) type test struct { - linker Linker + linker acme.Linker ctx context.Context prerequisitesChecker func(context.Context) (bool, error) next func(http.ResponseWriter, *http.Request) @@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "fail/prerequisites-nok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.checkPrerequisites(tc.next)(w, req) + checkPrerequisites(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/order.go b/acme/api/order.go index 99eb0e95..679fe32f 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -16,6 +16,8 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" ) // NewOrderRequest represents the body for a NewOrder request. @@ -37,6 +39,8 @@ func (n *NewOrderRequest) Validate() error { if id.Type == acme.IP && net.ParseIP(id.Value) == nil { return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) } + // TODO(hs): add some validations for DNS domains? + // TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 } return nil } @@ -50,7 +54,13 @@ type FinalizeRequest struct { // Validate validates a finalize request body. func (f *FinalizeRequest) Validate() error { var err error - csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) + // RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the + // CSR specifically, instead of "normal" base64-url encoding (incl. padding). + // By trimming the padding from CSRs submitted by ACME clients that use + // base64-url encoding instead of raw base64-url encoding, these are also + // supported. This was reported in https://github.com/smallstep/certificates/issues/939 + // to be the case for a Synology DSM NAS system. + csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "=")) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr") } @@ -68,8 +78,12 @@ var defaultOrderExpiry = time.Hour * 24 var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. -func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { +func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ca := mustAuthority(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -85,6 +99,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } + var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, @@ -97,6 +112,48 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } + // TODO(hs): gather all errors, so that we can build one response with ACME subproblems + // include the nor.Validate() error here too, like in the example in the ACME RFC? + + acmeProv, err := acmeProvisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } + + var eak *acme.ExternalAccountKey + if acmeProv.RequireEAB { + if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { + render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key")) + return + } + } + + acmePolicy, err := newACMEPolicyEngine(eak) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine")) + return + } + + for _, identifier := range nor.Identifiers { + // evaluate the ACME account level policy + if err = isIdentifierAllowed(acmePolicy, identifier); err != nil { + render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + // evaluate the provisioner level policy + orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value} + if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil { + render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + // evaluate the authority level policy + if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { + render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + } + now := clock.Now() // New order. o := &acme.Order{ @@ -117,7 +174,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ExpiresAt: o.ExpiresAt, Status: acme.StatusPending, } - if err := h.newAuthorization(ctx, az); err != nil { + if err := newAuthorization(ctx, az); err != nil { render.Error(w, err) return } @@ -136,18 +193,32 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } - if err := h.db.CreateOrder(ctx, o); err != nil { + if err := db.CreateOrder(ctx, o); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } - h.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSONStatus(w, o, http.StatusCreated) } -func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { +func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error { + if acmePolicy == nil { + return nil + } + return acmePolicy.AreSANsAllowed([]string{identifier.Value}) +} + +func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) { + if eak == nil { + return nil, nil + } + return policy.NewX509PolicyEngine(eak.Policy) +} + +func newAuthorization(ctx context.Context, az *acme.Authorization) error { if strings.HasPrefix(az.Identifier.Value, "*.") { az.Wildcard = true az.Identifier = acme.Identifier{ @@ -163,6 +234,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) if err != nil { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } + + db := acme.MustDatabaseFromContext(ctx) az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ @@ -172,20 +245,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) Token: az.Token, Status: acme.StatusPending, } - if err := h.db.CreateChallenge(ctx, ch); err != nil { + if err := db.CreateChallenge(ctx, ch); err != nil { return acme.WrapErrorISE(err, "error creating challenge") } az.Challenges[i] = ch } - if err = h.db.CreateAuthorization(ctx, az); err != nil { + if err = db.CreateAuthorization(ctx, az); err != nil { return acme.WrapErrorISE(err, "error creating authorization") } return nil } // GetOrder ACME api for retrieving an order. -func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { +func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -196,7 +272,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -211,20 +288,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.UpdateStatus(ctx, h.db); err != nil { + if err = o.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } - h.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } -// FinalizeOrder attemptst to finalize an order and create a certificate. -func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { +// FinalizeOrder attempts to finalize an order and create a certificate. +func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -251,7 +331,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -266,14 +346,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + + ca := mustAuthority(ctx) + if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } - h.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 1ce034e7..7f67c72e 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -16,9 +16,13 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + + "go.step.sm/crypto/pemutil" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "go.step.sm/crypto/pemutil" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" ) func TestNewOrderRequest_Validate(t *testing.T) { @@ -206,6 +210,13 @@ func TestFinalizeRequestValidate(t *testing.T) { }, } }, + "ok/padding": func(t *testing.T) test { + return test{ + fr: &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw) + "==", // add intentional padding + }, + } + }, } for name, run := range tests { tc := run(t) @@ -276,15 +287,17 @@ func TestHandler_GetOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -294,6 +307,7 @@ func TestHandler_GetOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -301,9 +315,10 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -311,7 +326,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -325,7 +340,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -341,7 +356,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -357,7 +372,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/order-update-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -381,10 +396,9 @@ func TestHandler_GetOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { @@ -421,11 +435,11 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetOrder(w, req) + GetOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -636,8 +650,8 @@ func TestHandler_newAuthorization(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - h := &Handler{db: tc.db} - if err := h.newAuthorization(context.Background(), tc.az); err != nil { + ctx := newBaseContext(context.Background(), tc.db) + if err := newAuthorization(ctx, tc.az); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *acme.Error: @@ -667,6 +681,7 @@ func TestHandler_NewOrder(t *testing.T) { baseURL.String(), escProvName) type test struct { + ca acme.CertificateAuthority db acme.DB ctx context.Context nor *NewOrderRequest @@ -677,15 +692,17 @@ func TestHandler_NewOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -695,6 +712,7 @@ func TestHandler_NewOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -702,9 +720,10 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -713,8 +732,9 @@ func TestHandler_NewOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -722,21 +742,23 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("paylod does not exist"), + err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), @@ -747,15 +769,232 @@ func TestHandler_NewOrder(t *testing.T) { fr := &NewOrderRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, + "fail/acmeProvisionerFromContext-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{}) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, errors.New("force") + }, + }, + err: acme.NewErrorISE("error retrieving external account binding key: force"), + } + }, + "fail/db.GetExternalAccountKeyByAccountID-error": func(t *testing.T) test { + acmeProv := newACMEProv(t) + acmeProv.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), acmeProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, errors.New("force") + }, + }, + err: acme.NewErrorISE("error retrieving external account binding key: force"), + } + }, + "fail/newACMEPolicyEngine-error": func(t *testing.T) test { + acmeProv := newACMEProv(t) + acmeProv.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), acmeProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"**.local"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewErrorISE("error creating ACME policy engine"), + } + }, + "fail/isIdentifierAllowed-error": func(t *testing.T) test { + acmeProv := newACMEProv(t) + acmeProv.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), acmeProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.local"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), + } + }, + "fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test { + options := &provisioner.Options{ + X509: &provisioner.X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + } + provWithPolicy := newACMEProvWithOptions(t, options) + provWithPolicy.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.internal"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), + } + }, + "fail/ca.AreSANsAllowed-error": func(t *testing.T) test { + options := &provisioner.Options{ + X509: &provisioner.X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.internal"}, + }, + }, + } + provWithPolicy := newACMEProvWithOptions(t, options) + provWithPolicy.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + ca: &mockCA{ + MockAreSANsallowed: func(ctx context.Context, sans []string) error { + return errors.New("force: not authorized by authority") + }, + }, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.internal"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), + } + }, "fail/error-h.newAuthorization": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ @@ -765,12 +1004,13 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.AccountID, "accID") @@ -780,6 +1020,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, ch.Value, "zap.internal") return errors.New("force") }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, err: acme.NewErrorISE("error creating challenge: force"), } @@ -793,7 +1038,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( @@ -804,6 +1049,7 @@ func TestHandler_NewOrder(t *testing.T) { return test{ ctx: ctx, statusCode: 500, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -849,6 +1095,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return errors.New("force") }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, err: acme.NewErrorISE("error creating order: force"), } @@ -863,10 +1114,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3, ch4 **acme.Challenge az1ID, az2ID *string @@ -876,6 +1126,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch chCount { @@ -945,6 +1196,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() @@ -978,10 +1234,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -991,6 +1246,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1037,6 +1293,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() @@ -1070,10 +1331,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1083,6 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1129,6 +1390,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() @@ -1161,10 +1427,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1174,6 +1439,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1220,6 +1486,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { testBufferDur := 5 * time.Second @@ -1253,10 +1524,109 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + ca: &mockCA{}, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, acme.DNS01) + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, acme.HTTP01) + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, acme.TLSALPN01) + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/default-naf-nbf-with-policy": func(t *testing.T) test { + options := &provisioner.Options{ + X509: &provisioner.X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.internal"}, + }, + }, + } + provWithPolicy := newACMEProvWithOptions(t, options) + provWithPolicy.RequireEAB = true + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1266,6 +1636,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1312,10 +1683,18 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + expNaf := now.Add(prov.DefaultTLSCertDuration()) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) @@ -1334,11 +1713,12 @@ func TestHandler_NewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + mockMustAuthority(t, tc.ca) + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.NewOrder(w, req) + NewOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1371,6 +1751,7 @@ func TestHandler_NewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { + mockMustAuthority(t, &mockCA{}) prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -1429,15 +1810,17 @@ func TestHandler_FinalizeOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -1447,6 +1830,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1454,9 +1838,10 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1465,8 +1850,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -1474,21 +1860,23 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("paylod does not exist"), + err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), @@ -1499,10 +1887,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), @@ -1511,7 +1900,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1526,7 +1915,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1543,7 +1932,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1560,7 +1949,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/order-finalize-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1585,10 +1974,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -1624,11 +2012,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.FinalizeOrder(w, req) + FinalizeOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 4b71bc22..a8b98f3f 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -26,9 +26,11 @@ type revokePayload struct { } // RevokeCert attempts to revoke a certificate. -func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { - +func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) @@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } serial := certToBeRevoked.SerialNumber.String() - dbCert, err := h.db.GetCertificateBySerial(ctx, serial) + dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return @@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) + acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { render.Error(w, acmeErr) return @@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } } - hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) + ca := mustAuthority(ctx) + hasBeenRevokedBefore, err := ca.IsRevoked(serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return @@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } options := revokeOptions(serial, certToBeRevoked, reasonCode) - err = h.ca.Revoke(ctx, options) + err = ca.Revoke(ctx, options) if err != nil { render.Error(w, wrapRevokeErr(err)) return } logRevoke(w, options) - w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) w.Write(nil) } @@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { // the identifiers in the certificate are extracted and compared against the (valid) Authorizations // that are stored for the ACME Account. If these sets match, the Account is considered authorized // to revoke the certificate. If this check fails, the client will receive an unauthorized error. -func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { +func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { if !account.IsValid() { return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) } diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 4ff54405..240ac748 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -24,14 +24,16 @@ import ( "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "golang.org/x/crypto/ocsp" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/x509util" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ocsp" ) // v is a utility function to return the pointer to an integer @@ -274,14 +276,22 @@ func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error } type mockCA struct { - MockIsRevoked func(sn string) (bool, error) - MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error + MockIsRevoked func(sn string) (bool, error) + MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error + MockAreSANsallowed func(ctx context.Context, sans []string) error } func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return nil, nil } +func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { + if m.MockAreSANsallowed != nil { + return m.MockAreSANsallowed(ctx, sans) + } + return nil +} + func (m *mockCA) IsRevoked(sn string) (bool, error) { if m.MockIsRevoked != nil { return m.MockIsRevoked(sn) @@ -511,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-jws": func(t *testing.T) test { ctx := context.Background() return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -519,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/nil-jws": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -527,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -534,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, nil) + ctx = acme.NewProvisionerContext(ctx, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -543,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -552,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -563,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/unmarshal-payload": func(t *testing.T) test { malformedPayload := []byte(`{"payload":malformed?}`) ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("error unmarshaling payload"), @@ -577,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) { } wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -596,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) { } emptyPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -610,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/db.GetCertificateBySerial": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -628,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/different-certificate-contents": func(t *testing.T) test { aDifferentCert, _, err := generateCertKeyPair() assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -647,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -666,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, accContextKey, nil) @@ -687,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -717,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-authorized": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -771,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -798,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-revoked-check-fails": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -832,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -870,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) { invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -908,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) + ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -940,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -972,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -1003,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "ok/using-account-key": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1031,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1057,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) + mockMustAuthority(t, tc.ca) req := httptest.NewRequest("POST", revokeURL, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.RevokeCert(w, req) + RevokeCert(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1198,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} - acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) + // h := &Handler{db: tc.db} + acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) expectError := tc.err != nil gotError := acmeErr != nil diff --git a/acme/challenge.go b/acme/challenge.go index 0e1994e4..8d8466bd 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net" - "net/http" "net/url" "reflect" "strings" @@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error { // If already valid or invalid then return without performing validation. if ch.Status != StatusPending { return nil } switch ch.Type { case HTTP01: - return http01Validate(ctx, ch, db, jwk, vo) + return http01Validate(ctx, ch, db, jwk) case DNS01: - return dns01Validate(ctx, ch, db, jwk, vo) + return dns01Validate(ctx, ch, db, jwk) case TLSALPN01: - return tlsalpn01Validate(ctx, ch, db, jwk, vo) + return tlsalpn01Validate(ctx, ch, db, jwk) default: return NewErrorISE("unexpected challenge type '%s'", ch.Type) } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { - u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { + u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - resp, err := vo.HTTPGet(u.String()) + vc := MustClientFromContext(ctx) + resp, err := vc.Get(u.String()) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", u)) @@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return nil } +// http01ChallengeHost checks if a Challenge value is an IPv6 address +// and adds square brackets if that's the case, so that it can be used +// as a hostname. Returns the original Challenge value as the host to +// use in other cases. +func http01ChallengeHost(value string) string { + if ip := net.ParseIP(value); ip != nil && ip.To4() == nil { + value = "[" + value + "]" + } + return value +} + func tlsAlert(err error) uint8 { var opErr *net.OpError if errors.As(err, &opErr) { @@ -130,7 +141,7 @@ func tlsAlert(err error) uint8 { return 0 } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, // https://tools.ietf.org/html/rfc8737#section-4 @@ -143,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.TLSDial("tcp", hostPort, config) + vc := MustClientFromContext(ctx) + conn, err := vc.TLSDial("tcp", hostPort, config) if err != nil { // With Go 1.17+ tls.Dial fails if there's no overlap between configured // client and server protocols. When this happens the connection is @@ -242,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) + vc := MustClientFromContext(ctx) + txtRecords, err := vc.LookupTxt("_acme-challenge." + domain) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) @@ -365,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err } return nil } - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -// ValidateChallengeOptions are ACME challenge validator functions. -type ValidateChallengeOptions struct { - HTTPGet httpGetter - LookupTxt lookupTxt - TLSDial tlsDialer -} diff --git a/acme/challenge_test.go b/acme/challenge_test.go index d8ce4d76..e1b6816a 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -13,6 +13,7 @@ import ( "encoding/asn1" "encoding/base64" "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -23,11 +24,23 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + func Test_storeError(t *testing.T) { type test struct { ch *Challenge @@ -228,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) { func TestChallenge_Validate(t *testing.T) { type test struct { ch *Challenge - vo *ValidateChallengeOptions + vc Client jwk *jose.JSONWebKey db DB srv *httptest.Server @@ -272,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -308,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -343,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -380,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -415,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) { } return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -465,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -492,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) { defer tc.srv.Close() } - if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -523,7 +537,7 @@ func (errReader) Close() error { func TestHTTP01Validate(t *testing.T) { type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -540,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -574,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -607,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -644,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -680,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil @@ -703,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) { jwk.Key = "foo" return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -729,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -771,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -814,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -856,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -886,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -910,7 +925,7 @@ func TestDNS01Validate(t *testing.T) { fulldomain := "*.zap.internal" domain := strings.TrimPrefix(fulldomain, "*.") type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -927,8 +942,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -962,8 +977,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -1000,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo"}, nil }, }, @@ -1025,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1067,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1110,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1155,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1185,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -1205,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) { } } +type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) + func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { srv := httptest.NewUnstartedServer(http.NewServeMux()) @@ -1308,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) { } } type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -1320,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1350,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1383,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1412,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1442,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1478,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1515,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1561,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1604,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1648,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1691,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1735,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, srv: srv, jwk: jwk, @@ -1757,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1796,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1840,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1883,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1923,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1962,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2007,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2053,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2099,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2143,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2188,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2225,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2252,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) { defer tc.srv.Close() } - if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -2350,3 +2369,34 @@ func Test_serverName(t *testing.T) { }) } } + +func Test_http01ChallengeHost(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + { + name: "dns", + value: "www.example.com", + want: "www.example.com", + }, + { + name: "ipv4", + value: "127.0.0.1", + want: "127.0.0.1", + }, + { + name: "ipv6", + value: "::1", + want: "[::1]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := http01ChallengeHost(tt.value); got != tt.want { + t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/acme/client.go b/acme/client.go new file mode 100644 index 00000000..31f4c975 --- /dev/null +++ b/acme/client.go @@ -0,0 +1,79 @@ +package acme + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" +) + +// Client is the interface used to verify ACME challenges. +type Client interface { + // Get issues an HTTP GET to the specified URL. + Get(url string) (*http.Response, error) + + // LookupTXT returns the DNS TXT records for the given domain name. + LookupTxt(name string) ([]string, error) + + // TLSDial connects to the given network address using net.Dialer and then + // initiates a TLS handshake, returning the resulting TLS connection. + TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +type clientKey struct{} + +// NewClientContext adds the given client to the context. +func NewClientContext(ctx context.Context, c Client) context.Context { + return context.WithValue(ctx, clientKey{}, c) +} + +// ClientFromContext returns the current client from the given context. +func ClientFromContext(ctx context.Context) (c Client, ok bool) { + c, ok = ctx.Value(clientKey{}).(Client) + return +} + +// MustClientFromContext returns the current client from the given context. It will +// return a new instance of the client if it does not exist. +func MustClientFromContext(ctx context.Context) Client { + c, ok := ClientFromContext(ctx) + if !ok { + return NewClient() + } + return c +} + +type client struct { + http *http.Client + dialer *net.Dialer +} + +// NewClient returns an implementation of Client for verifying ACME challenges. +func NewClient() Client { + return &client{ + http: &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + dialer: &net.Dialer{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *client) Get(url string) (*http.Response, error) { + return c.http.Get(url) +} + +func (c *client) LookupTxt(name string) ([]string, error) { + return net.LookupTXT(name) +} + +func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(c.dialer, network, addr, config) +} diff --git a/acme/common.go b/acme/common.go index 0c9e83dc..3054abe1 100644 --- a/acme/common.go +++ b/acme/common.go @@ -9,27 +9,66 @@ import ( "github.com/smallstep/certificates/authority/provisioner" ) +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock Clock + // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error LoadProvisionerByName(string) (provisioner.Interface, error) } -// Clock that returns time in UTC rounded to seconds. -type Clock struct{} +// NewContext adds the given acme components to the context. +func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { + ctx = NewDatabaseContext(ctx, db) + ctx = NewClientContext(ctx, client) + ctx = NewLinkerContext(ctx, linker) + // Prerequisite checker is optional. + if fn != nil { + ctx = NewPrerequisitesCheckerContext(ctx, fn) + } + return ctx +} -// Now returns the UTC time rounded to seconds. -func (c *Clock) Now() time.Time { - return time.Now().UTC().Truncate(time.Second) +// PrerequisitesChecker is a function that checks if all prerequisites for +// serving ACME are met by the CA configuration. +type PrerequisitesChecker func(ctx context.Context) (bool, error) + +// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns +// always true. +func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) { + return true, nil } -var clock Clock +type prerequisitesKey struct{} + +// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the +// context. +func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context { + return context.WithValue(ctx, prerequisitesKey{}, fn) +} + +// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the +// context. +func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) { + fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker) + return fn, ok && fn != nil +} // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { + AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error GetID() string @@ -38,16 +77,40 @@ type Provisioner interface { GetOptions() *provisioner.Options } +type provisionerKey struct{} + +// NewProvisionerContext adds the given provisioner to the context. +func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, v) +} + +// ProvisionerFromContext returns the current provisioner from the given context. +func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) { + v, ok = ctx.Value(provisionerKey{}).(Provisioner) + return +} + +// MustLinkerFromContext returns the current provisioner from the given context. +// It will panic if it's not in the context. +func MustProvisionerFromContext(ctx context.Context) Provisioner { + if v, ok := ProvisionerFromContext(ctx); !ok { + panic("acme provisioner is not the context") + } else { + return v + } +} + // MockProvisioner for testing type MockProvisioner struct { - Mret1 interface{} - Merr error - MgetID func() string - MgetName func() string - MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) - MauthorizeRevoke func(ctx context.Context, token string) error - MdefaultTLSCertDuration func() time.Duration - MgetOptions func() *provisioner.Options + Mret1 interface{} + Merr error + MgetID func() string + MgetName func() string + MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error + MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) + MauthorizeRevoke func(ctx context.Context, token string) error + MdefaultTLSCertDuration func() time.Duration + MgetOptions func() *provisioner.Options } // GetName mock @@ -58,6 +121,14 @@ func (m *MockProvisioner) GetName() string { return m.Mret1.(string) } +// AuthorizeOrderIdentifiers mock +func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { + if m.MauthorizeOrderIdentifier != nil { + return m.MauthorizeOrderIdentifier(ctx, identifier) + } + return m.Merr +} + // AuthorizeSign mock func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.MauthorizeSign != nil { diff --git a/acme/db.go b/acme/db.go index 412276fd..d7c9d5f4 100644 --- a/acme/db.go +++ b/acme/db.go @@ -23,6 +23,7 @@ type DB interface { GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error @@ -48,6 +49,29 @@ type DB interface { UpdateOrder(ctx context.Context, o *Order) error } +type dbKey struct{} + +// NewDatabaseContext adds the given acme database to the context. +func NewDatabaseContext(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// DatabaseFromContext returns the current acme database from the given context. +func DatabaseFromContext(ctx context.Context) (db DB, ok bool) { + db, ok = ctx.Value(dbKey{}).(DB) + return +} + +// MustDatabaseFromContext returns the current database from the given context. +// It will panic if it's not in the context. +func MustDatabaseFromContext(ctx context.Context) DB { + if db, ok := DatabaseFromContext(ctx); !ok { + panic("acme database is not in the context") + } else { + return db + } +} + // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. type MockDB struct { @@ -60,6 +84,7 @@ type MockDB struct { MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error @@ -168,6 +193,16 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision return m.MockRet1.(*ExternalAccountKey), m.MockError } +// GetExternalAccountKeyByAccountID mock +func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) { + if m.MockGetExternalAccountKeyByAccountID != nil { + return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*ExternalAccountKey), m.MockError +} + // DeleteExternalAccountKey mock func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { if m.MockDeleteExternalAccountKey != nil { diff --git a/acme/db/nosql/eab.go b/acme/db/nosql/eab.go index f9a24daf..e87aa9bc 100644 --- a/acme/db/nosql/eab.go +++ b/acme/db/nosql/eab.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" nosqlDB "github.com/smallstep/nosql" ) @@ -23,7 +24,7 @@ type dbExternalAccountKey struct { ProvisionerID string `json:"provisionerID"` Reference string `json:"reference"` AccountID string `json:"accountID,omitempty"` - KeyBytes []byte `json:"key"` + HmacKey []byte `json:"key"` CreatedAt time.Time `json:"createdAt"` BoundAt time.Time `json:"boundAt"` } @@ -72,7 +73,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer ID: keyID, ProvisionerID: provisionerID, Reference: reference, - KeyBytes: random, + HmacKey: random, CreatedAt: clock.Now(), } @@ -99,7 +100,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer ProvisionerID: dbeak.ProvisionerID, Reference: dbeak.Reference, AccountID: dbeak.AccountID, - KeyBytes: dbeak.KeyBytes, + HmacKey: dbeak.HmacKey, CreatedAt: dbeak.CreatedAt, BoundAt: dbeak.BoundAt, }, nil @@ -124,7 +125,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st ProvisionerID: dbeak.ProvisionerID, Reference: dbeak.Reference, AccountID: dbeak.AccountID, - KeyBytes: dbeak.KeyBytes, + HmacKey: dbeak.HmacKey, CreatedAt: dbeak.CreatedAt, BoundAt: dbeak.BoundAt, }, nil @@ -191,7 +192,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor } keys = append(keys, &acme.ExternalAccountKey{ ID: eak.ID, - KeyBytes: eak.KeyBytes, + HmacKey: eak.HmacKey, ProvisionerID: eak.ProvisionerID, Reference: eak.Reference, AccountID: eak.AccountID, @@ -226,6 +227,10 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID) } +func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + return nil, nil +} + func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { externalAccountKeyMutex.Lock() defer externalAccountKeyMutex.Unlock() @@ -252,7 +257,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string ProvisionerID: eak.ProvisionerID, Reference: eak.Reference, AccountID: eak.AccountID, - KeyBytes: eak.KeyBytes, + HmacKey: eak.HmacKey, CreatedAt: eak.CreatedAt, BoundAt: eak.BoundAt, } diff --git a/acme/db/nosql/eab_test.go b/acme/db/nosql/eab_test.go index 568500e9..525afa72 100644 --- a/acme/db/nosql/eab_test.go +++ b/acme/db/nosql/eab_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" certdb "github.com/smallstep/certificates/db" @@ -32,7 +33,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -108,7 +109,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) { } } else if assert.Nil(t, tc.err) { assert.Equals(t, dbeak.ID, tc.dbeak.ID) - assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes) + assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey) assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID) assert.Equals(t, dbeak.Reference, tc.dbeak.Reference) assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt) @@ -136,7 +137,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -154,7 +155,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, } @@ -179,7 +180,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -197,7 +198,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"), @@ -225,7 +226,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { } } else if assert.Nil(t, tc.err) { assert.Equals(t, eak.ID, tc.eak.ID) - assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, eak.HmacKey, tc.eak.HmacKey) assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, eak.Reference, tc.eak.Reference) assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) @@ -255,7 +256,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -288,7 +289,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, err: nil, @@ -392,7 +393,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) { assert.Equals(t, eak.AccountID, tc.eak.AccountID) assert.Equals(t, eak.BoundAt, tc.eak.BoundAt) assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) - assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, eak.HmacKey, tc.eak.HmacKey) assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, eak.Reference, tc.eak.Reference) } @@ -420,7 +421,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b1, err := json.Marshal(dbeak1) @@ -430,7 +431,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b2, err := json.Marshal(dbeak2) @@ -440,7 +441,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b3, err := json.Marshal(dbeak3) @@ -513,7 +514,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, { @@ -521,7 +522,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, }, @@ -598,7 +599,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { assert.Equals(t, "", nextCursor) for i, eak := range eaks { assert.Equals(t, eak.ID, tc.eaks[i].ID) - assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes) + assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey) assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID) assert.Equals(t, eak.Reference, tc.eaks[i].Reference) assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt) @@ -627,7 +628,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -707,7 +708,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -730,7 +731,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -780,7 +781,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -830,7 +831,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -953,7 +954,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) { assert.Equals(t, string(key), dbeak.ID) assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID) assert.Equals(t, eak.Reference, dbeak.Reference) - assert.Equals(t, 32, len(dbeak.KeyBytes)) + assert.Equals(t, 32, len(dbeak.HmacKey)) assert.False(t, dbeak.CreatedAt.IsZero()) assert.Equals(t, dbeak.AccountID, eak.AccountID) assert.True(t, dbeak.BoundAt.IsZero()) @@ -1078,7 +1079,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -1096,7 +1097,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } return test{ @@ -1120,7 +1121,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { assert.Equals(t, dbNew.AccountID, dbeak.AccountID) assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt) assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt) - assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes) + assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey) return nu, true, nil }, }, @@ -1148,7 +1149,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) @@ -1174,7 +1175,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) @@ -1200,7 +1201,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) @@ -1237,7 +1238,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { assert.Equals(t, dbeak.AccountID, tc.eak.AccountID) assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt) - assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey) } }) } diff --git a/acme/api/linker.go b/acme/linker.go similarity index 59% rename from acme/api/linker.go rename to acme/linker.go index a605ffc3..bddc21f1 100644 --- a/acme/api/linker.go +++ b/acme/linker.go @@ -1,100 +1,19 @@ -package api +package acme import ( "context" "fmt" "net" + "net/http" "net/url" "strings" - "github.com/smallstep/certificates/acme" + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" ) -// NewLinker returns a new Directory type. -func NewLinker(dns, prefix string) Linker { - _, _, err := net.SplitHostPort(dns) - if err != nil && strings.Contains(err.Error(), "too many colons in address") { - // this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334 - // in case a port was appended to this wrong format, we try to extract the port, then check if it's - // still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of - // these cases, then the input dns is not changed. - lastIndex := strings.LastIndex(dns, ":") - hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:] - if ip := net.ParseIP(hostPart); ip != nil { - dns = "[" + hostPart + "]:" + portPart - } else if ip := net.ParseIP(dns); ip != nil { - dns = "[" + dns + "]" - } - } - return &linker{prefix: prefix, dns: dns} -} - -// Linker interface for generating links for ACME resources. -type Linker interface { - GetLink(ctx context.Context, typ LinkType, inputs ...string) string - GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string - - LinkOrder(ctx context.Context, o *acme.Order) - LinkAccount(ctx context.Context, o *acme.Account) - LinkChallenge(ctx context.Context, o *acme.Challenge, azID string) - LinkAuthorization(ctx context.Context, o *acme.Authorization) - LinkOrdersByAccountID(ctx context.Context, orders []string) -} - -// linker generates ACME links. -type linker struct { - prefix string - dns string -} - -func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { - switch typ { - case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: - return fmt.Sprintf("/%s/%s", provisionerName, typ) - case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: - return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) - case ChallengeLinkType: - return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) - case OrdersByAccountLinkType: - return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) - case FinalizeLinkType: - return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) - default: - return "" - } -} - -// GetLink is a helper for GetLinkExplicit -func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { - var ( - provName string - baseURL = baseURLFromContext(ctx) - u = url.URL{} - ) - if p, err := provisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - if baseURL != nil { - u = *baseURL - } - - u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...) - - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } - - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = l.dns - } - - u.Path = l.prefix + u.Path - return u.String() -} - // LinkType captures the link type. type LinkType int @@ -160,8 +79,155 @@ func (l LinkType) String() string { } } +func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + return fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + default: + return "" + } +} + +// NewLinker returns a new Directory type. +func NewLinker(dns, prefix string) Linker { + _, _, err := net.SplitHostPort(dns) + if err != nil && strings.Contains(err.Error(), "too many colons in address") { + // this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + // in case a port was appended to this wrong format, we try to extract the port, then check if it's + // still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of + // these cases, then the input dns is not changed. + lastIndex := strings.LastIndex(dns, ":") + hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:] + if ip := net.ParseIP(hostPart); ip != nil { + dns = "[" + hostPart + "]:" + portPart + } else if ip := net.ParseIP(dns); ip != nil { + dns = "[" + dns + "]" + } + } + return &linker{prefix: prefix, dns: dns} +} + +// Linker interface for generating links for ACME resources. +type Linker interface { + GetLink(ctx context.Context, typ LinkType, inputs ...string) string + Middleware(http.Handler) http.Handler + LinkOrder(ctx context.Context, o *Order) + LinkAccount(ctx context.Context, o *Account) + LinkChallenge(ctx context.Context, o *Challenge, azID string) + LinkAuthorization(ctx context.Context, o *Authorization) + LinkOrdersByAccountID(ctx context.Context, orders []string) +} + +type linkerKey struct{} + +// NewLinkerContext adds the given linker to the context. +func NewLinkerContext(ctx context.Context, v Linker) context.Context { + return context.WithValue(ctx, linkerKey{}, v) +} + +// LinkerFromContext returns the current linker from the given context. +func LinkerFromContext(ctx context.Context) (v Linker, ok bool) { + v, ok = ctx.Value(linkerKey{}).(Linker) + return +} + +// MustLinkerFromContext returns the current linker from the given context. It +// will panic if it's not in the context. +func MustLinkerFromContext(ctx context.Context) Linker { + if v, ok := LinkerFromContext(ctx); !ok { + panic("acme linker is not the context") + } else { + return v + } +} + +type baseURLKey struct{} + +func newBaseURLContext(ctx context.Context, r *http.Request) context.Context { + var u *url.URL + if r.Host != "" { + u = &url.URL{Scheme: "https", Host: r.Host} + } + return context.WithValue(ctx, baseURLKey{}, u) +} + +func baseURLFromContext(ctx context.Context) *url.URL { + if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok { + return u + } + return nil +} + +// linker generates ACME links. +type linker struct { + prefix string + dns string +} + +// Middleware gets the provisioner and current url from the request and sets +// them in the context so we can use the linker to create ACME links. +func (l *linker) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add base url to the context. + ctx := newBaseURLContext(r.Context(), r) + + // Add provisioner to the context. + nameEscaped := chi.URLParam(r, "provisionerID") + name, err := url.PathUnescape(nameEscaped) + if err != nil { + render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + return + } + + p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) + if err != nil { + render.Error(w, err) + return + } + + acmeProv, ok := p.(*provisioner.ACME) + if !ok { + render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + return + } + + ctx = NewProvisionerContext(ctx, Provisioner(acmeProv)) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetLink is a helper for GetLinkExplicit. +func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { + var name string + if p, ok := ProvisionerFromContext(ctx); ok { + name = p.GetName() + } + + var u url.URL + if baseURL := baseURLFromContext(ctx); baseURL != nil { + u = *baseURL + } + if u.Scheme == "" { + u.Scheme = "https" + } + if u.Host == "" { + u.Host = l.dns + } + + u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...) + return u.String() +} + // LinkOrder sets the ACME links required by an ACME order. -func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { +func (l *linker) LinkOrder(ctx context.Context, o *Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) @@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { } // LinkAccount sets the ACME links required by an ACME account. -func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { +func (l *linker) LinkAccount(ctx context.Context, acc *Account) { acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. -func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { +func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) { ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. -func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { +func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) { for _, ch := range az.Challenges { l.LinkChallenge(ctx, ch, az.ID) } diff --git a/acme/api/linker_test.go b/acme/linker_test.go similarity index 82% rename from acme/api/linker_test.go rename to acme/linker_test.go index 74c2c8b0..b85d1a53 100644 --- a/acme/api/linker_test.go +++ b/acme/linker_test.go @@ -1,21 +1,38 @@ -package api +package acme import ( "context" "fmt" "net/url" "testing" + "time" "github.com/smallstep/assert" - "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" ) -func TestLinker_GetUnescapedPathSuffix(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - linker := NewLinker(dns, prefix) +func mockProvisioner(t *testing.T) Provisioner { + t.Helper() + var defaultDisableRenewal = false + + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(provisioner.Config{Claims: provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + }}); err != nil { + fmt.Printf("%v", err) + } + return p +} - getPath := linker.GetUnescapedPathSuffix +func TestGetUnescapedPathSuffix(t *testing.T) { + getPath := GetUnescapedPathSuffix assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") @@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) { } func TestLinker_DNS(t *testing.T) { - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) type test struct { name string dns string @@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) { linker := NewLinker(dns, prefix) id := "1234" - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) // No provisioner and no BaseURL from request assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) // Provisioner: yes, BaseURL: no - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) // Provisioner: no, BaseURL: yes - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) @@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) oid := "orderID" certID := "certID" linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - o *acme.Order - validate func(o *acme.Order) + o *Order + validate func(o *Order) } var tests = map[string]test{ "no-authz-and-no-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.CertificateURL, "") }, }, "one-authz-and-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) { }, }, "many-authz": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo", "bar", "zap"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) accID := "accountID" linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - a *acme.Account - validate func(o *acme.Account) + a *Account + validate func(o *Account) } var tests = map[string]test{ "ok": { - a: &acme.Account{ + a: &Account{ ID: accID, }, - validate: func(a *acme.Account) { + validate: func(a *Account) { assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) }, }, @@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID := "chID" azID := "azID" linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - ch *acme.Challenge - validate func(o *acme.Challenge) + ch *Challenge + validate func(o *Challenge) } var tests = map[string]test{ "ok": { - ch: &acme.Challenge{ + ch: &Challenge{ ID: chID, }, - validate: func(ch *acme.Challenge) { + validate: func(ch *Challenge) { assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) }, }, @@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID0 := "chID-0" chID1 := "chID-1" @@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - az *acme.Authorization - validate func(o *acme.Authorization) + az *Authorization + validate func(o *Authorization) } var tests = map[string]test{ "ok": { - az: &acme.Authorization{ + az: &Authorization{ ID: azID, - Challenges: []*acme.Challenge{ + Challenges: []*Challenge{ {ID: chID0}, {ID: chID1}, {ID: chID2}, }, }, - validate: func(az *acme.Authorization) { + validate: func(az *Authorization) { assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) @@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) diff --git a/acme/order_test.go b/acme/order_test.go index 493b40b7..f1f28e40 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -268,6 +268,7 @@ func TestOrder_UpdateStatus(t *testing.T) { type mockSignAuth struct { sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} err error @@ -282,6 +283,13 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { + if m.areSANsAllowed != nil { + return m.areSANsAllowed(ctx, sans) + } + return m.err +} + func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) { if m.loadProvisionerByName != nil { return m.loadProvisionerByName(name) diff --git a/api/api.go b/api/api.go index 014a89e9..bf5c8632 100644 --- a/api/api.go +++ b/api/api.go @@ -35,7 +35,6 @@ type Authority interface { SSHAuthority // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) - AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) @@ -53,6 +52,11 @@ type Authority interface { GetCertificateRevocationList() ([]byte, error) } +// mustAuthority will be replaced on unit tests. +var mustAuthority = func(ctx context.Context) Authority { + return authority.MustFromContext(ctx) +} + // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -244,49 +248,54 @@ type caHandler struct { Authority Authority } -// New creates a new RouterHandler with the CA endpoints. -func New(auth Authority) RouterHandler { - return &caHandler{ - Authority: auth, - } +// Route configures the http request router. +func (h *caHandler) Route(r Router) { + Route(r) } -func (h *caHandler) Route(r Router) { - r.MethodFunc("GET", "/version", h.Version) - r.MethodFunc("GET", "/health", h.Health) - r.MethodFunc("GET", "/root/{sha}", h.Root) - r.MethodFunc("POST", "/sign", h.Sign) - r.MethodFunc("POST", "/renew", h.Renew) - r.MethodFunc("POST", "/rekey", h.Rekey) - r.MethodFunc("POST", "/revoke", h.Revoke) - r.MethodFunc("GET", "/crl", h.CRL) - r.MethodFunc("GET", "/provisioners", h.Provisioners) - r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) - r.MethodFunc("GET", "/roots", h.Roots) - r.MethodFunc("GET", "/roots.pem", h.RootsPEM) - r.MethodFunc("GET", "/federation", h.Federation) +// New creates a new RouterHandler with the CA endpoints. +// +// Deprecated: Use api.Route(r Router) +func New(auth Authority) RouterHandler { + return &caHandler{} +} + +func Route(r Router) { + r.MethodFunc("GET", "/version", Version) + r.MethodFunc("GET", "/health", Health) + r.MethodFunc("GET", "/root/{sha}", Root) + r.MethodFunc("POST", "/sign", Sign) + r.MethodFunc("POST", "/renew", Renew) + r.MethodFunc("POST", "/rekey", Rekey) + r.MethodFunc("POST", "/revoke", Revoke) + r.MethodFunc("GET", "/crl", CRL) + r.MethodFunc("GET", "/provisioners", Provisioners) + r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey) + r.MethodFunc("GET", "/roots", Roots) + r.MethodFunc("GET", "/roots.pem", RootsPEM) + r.MethodFunc("GET", "/federation", Federation) // SSH CA - r.MethodFunc("POST", "/ssh/sign", h.SSHSign) - r.MethodFunc("POST", "/ssh/renew", h.SSHRenew) - r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke) - r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey) - r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) - r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) - r.MethodFunc("POST", "/ssh/config", h.SSHConfig) - r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) - r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) - r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts) - r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) + r.MethodFunc("POST", "/ssh/sign", SSHSign) + r.MethodFunc("POST", "/ssh/renew", SSHRenew) + r.MethodFunc("POST", "/ssh/revoke", SSHRevoke) + r.MethodFunc("POST", "/ssh/rekey", SSHRekey) + r.MethodFunc("GET", "/ssh/roots", SSHRoots) + r.MethodFunc("GET", "/ssh/federation", SSHFederation) + r.MethodFunc("POST", "/ssh/config", SSHConfig) + r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig) + r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost) + r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts) + r.MethodFunc("POST", "/ssh/bastion", SSHBastion) // For compatibility with old code: - r.MethodFunc("POST", "/re-sign", h.Renew) - r.MethodFunc("POST", "/sign-ssh", h.SSHSign) - r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) + r.MethodFunc("POST", "/re-sign", Renew) + r.MethodFunc("POST", "/sign-ssh", SSHSign) + r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts) } // Version is an HTTP handler that returns the version of the server. -func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { - v := h.Authority.Version() +func Version(w http.ResponseWriter, r *http.Request) { + v := mustAuthority(r.Context()).Version() render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, @@ -294,17 +303,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { } // Health is an HTTP handler that returns the status of the server. -func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { +func Health(w http.ResponseWriter, r *http.Request) { render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root // certificate for the given SHA256. -func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { +func Root(w http.ResponseWriter, r *http.Request) { sha := chi.URLParam(r, "sha") sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) // Load root certificate with the - cert, err := h.Authority.Root(sum) + cert, err := mustAuthority(r.Context()).Root(sum) if err != nil { render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return @@ -322,18 +331,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { } // Provisioners returns the list of provisioners configured in the authority. -func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { +func Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { render.Error(w, err) return } - p, next, err := h.Authority.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return } + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, @@ -341,19 +351,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { } // ProvisionerKey returns the encrypted key of a provisioner by it's key id. -func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { +func ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") - key, err := h.Authority.GetEncryptedKey(kid) + key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) if err != nil { render.Error(w, errs.NotFoundErr(err)) return } + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. -func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func Roots(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return @@ -370,8 +381,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { } // RootsPEM returns all the root certificates for the CA in PEM format. -func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func RootsPEM(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -393,8 +404,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { } // Federation returns all the public certificates in the federation. -func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { - federated, err := h.Authority.GetFederation() +func Federation(w http.ResponseWriter, r *http.Request) { + federated, err := mustAuthority(r.Context()).GetFederation() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return diff --git a/api/api_test.go b/api/api_test.go index 06b165bf..485244b9 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest { return csr } +func mockMustAuthority(t *testing.T, a Authority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) Authority { + return a + } +} + type mockAuthority struct { ret1, ret2 interface{} err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) @@ -207,12 +218,8 @@ func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) { // TODO: remove once Authorize is deprecated. func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) + if m.authorize != nil { + return m.authorize(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } @@ -793,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) { } } -func Test_caHandler_Health(t *testing.T) { +func Test_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() - h := New(&mockAuthority{}).(*caHandler) - h.Health(w, req) + Health(w, req) res := w.Result() if res.StatusCode != 200 { @@ -815,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) { } } -func Test_caHandler_Root(t *testing.T) { +func Test_Root(t *testing.T) { tests := []struct { name string root *x509.Certificate @@ -836,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err}) w := httptest.NewRecorder() - h.Root(w, req) + Root(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -859,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) { } } -func Test_caHandler_Sign(t *testing.T) { +func Test_Sign(t *testing.T) { csr := parseCertificateRequest(csrPEM) valid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, @@ -900,18 +906,18 @@ func Test_caHandler_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr }, getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) w := httptest.NewRecorder() - h.Sign(logging.NewResponseLogger(w), req) + Sign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -932,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) { } } -func Test_caHandler_Renew(t *testing.T) { +func Test_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1022,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) @@ -1043,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) { getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls req.Header = tt.header w := httptest.NewRecorder() - h.Renew(logging.NewResponseLogger(w), req) + Renew(logging.NewResponseLogger(w), req) res := w.Result() defer res.Body.Close() @@ -1077,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) { } } -func Test_caHandler_Rekey(t *testing.T) { +func Test_Rekey(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1108,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req.TLS = tt.tls w := httptest.NewRecorder() - h.Rekey(logging.NewResponseLogger(w), req) + Rekey(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1138,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) { } } -func Test_caHandler_Provisioners(t *testing.T) { +func Test_Provisioners(t *testing.T) { type fields struct { Authority Authority } @@ -1204,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) { assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &caHandler{ - Authority: tt.fields.Authority, - } - h.Provisioners(tt.args.w, tt.args.r) + mockMustAuthority(t, tt.fields.Authority) + Provisioners(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() @@ -1242,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) { } } -func Test_caHandler_ProvisionerKey(t *testing.T) { +func Test_ProvisionerKey(t *testing.T) { type fields struct { Authority Authority } @@ -1274,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &caHandler{ - Authority: tt.fields.Authority, - } - h.ProvisionerKey(tt.args.w, tt.args.r) + mockMustAuthority(t, tt.fields.Authority) + ProvisionerKey(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() @@ -1302,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } } -func Test_caHandler_Roots(t *testing.T) { +func Test_Roots(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1323,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() - h.Roots(w, req) + Roots(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1364,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err}) req := httptest.NewRequest("GET", "https://example.com/roots", nil) w := httptest.NewRecorder() - h.RootsPEM(w, req) + RootsPEM(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1388,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) { } } -func Test_caHandler_Federation(t *testing.T) { +func Test_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1409,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() - h.Federation(w, req) + Federation(w, req) res := w.Result() if res.StatusCode != tt.statusCode { diff --git a/api/read/read.go b/api/read/read.go index de92c5d7..72530b8c 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -3,16 +3,20 @@ package read import ( "encoding/json" + "errors" "io" + "net/http" + "strings" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) // JSON reads JSON from the request body and stores it in the value -// pointed by v. +// pointed to by v. func JSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { return errs.BadRequestErr(err, "error decoding json") @@ -21,11 +25,42 @@ func JSON(r io.Reader, v interface{}) error { } // ProtoJSON reads JSON from the request body and stores it in the value -// pointed by v. +// pointed to by m. func ProtoJSON(r io.Reader, m proto.Message) error { data, err := io.ReadAll(r) if err != nil { return errs.BadRequestErr(err, "error reading request body") } - return protojson.Unmarshal(data, m) + + switch err := protojson.Unmarshal(data, m); { + case errors.Is(err, proto.Error): + return badProtoJSONError(err.Error()) + default: + return err + } +} + +// badProtoJSONError is an error type that is returned by ProtoJSON +// when a proto message cannot be unmarshaled. Usually this is caused +// by an error in the request body. +type badProtoJSONError string + +// Error implements error for badProtoJSONError +func (e badProtoJSONError) Error() string { + return string(e) +} + +// Render implements render.RenderableError for badProtoJSONError +func (e badProtoJSONError) Render(w http.ResponseWriter) { + v := struct { + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + }{ + Type: "badRequest", + Detail: "bad request", + // trim the proto prefix for the message + Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")), + } + render.JSONStatus(w, v, http.StatusBadRequest) } diff --git a/api/read/read_test.go b/api/read/read_test.go index f2eff1bc..72100584 100644 --- a/api/read/read_test.go +++ b/api/read/read_test.go @@ -1,10 +1,21 @@ package read import ( + "encoding/json" + "errors" "io" + "net/http" + "net/http/httptest" "reflect" "strings" "testing" + "testing/iotest" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "go.step.sm/linkedca" "github.com/smallstep/certificates/errs" ) @@ -44,3 +55,110 @@ func TestJSON(t *testing.T) { }) } } + +func TestProtoJSON(t *testing.T) { + + p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import? + + type args struct { + r io.Reader + m proto.Message + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "fail/io.ReadAll", + args: args{ + r: iotest.ErrReader(errors.New("read error")), + m: p, + }, + wantErr: true, + }, + { + name: "fail/proto", + args: args{ + r: strings.NewReader(`{?}`), + m: p, + }, + wantErr: true, + }, + { + name: "ok", + args: args{ + r: strings.NewReader(`{"x509":{}}`), + m: p, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ProtoJSON(tt.args.r, tt.args.m) + if (err != nil) != tt.wantErr { + t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + switch err.(type) { + case badProtoJSONError: + assert.Contains(t, err.Error(), "syntax error") + case *errs.Error: + var ee *errs.Error + if errors.As(err, &ee) { + assert.Equal(t, http.StatusBadRequest, ee.Status) + } + } + return + } + + assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m)) + assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m)) + }) + } +} + +func Test_badProtoJSONError_Render(t *testing.T) { + tests := []struct { + name string + e badProtoJSONError + expected string + }{ + { + name: "bad proto normal space", + e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), + expected: "syntax error (line 1:2): invalid value ?", + }, + { + name: "bad proto non breaking space", + e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), + expected: "syntax error (line 1:2): invalid value ?", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + w := httptest.NewRecorder() + tt.e.Render(w) + res := w.Result() + defer res.Body.Close() + + data, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + v := struct { + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + }{} + + assert.NoError(t, json.Unmarshal(data, &v)) + assert.Equal(t, "badRequest", v.Type) + assert.Equal(t, "bad request", v.Detail) + assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message) + + }) + } +} diff --git a/api/rekey.go b/api/rekey.go index 3116cf74..cda843a3 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error { } // Rekey is similar to renew except that the certificate will be renewed with new key from csr. -func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { +func Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { render.Error(w, errs.BadRequest("missing client certificate")) return @@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { return } - certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) + a := mustAuthority(r.Context()) + certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return @@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/renew.go b/api/renew.go index 9c4bff32..6e9f680f 100644 --- a/api/renew.go +++ b/api/renew.go @@ -16,14 +16,15 @@ const ( // Renew uses the information of certificate in the TLS connection to create a // new one. -func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - cert, err := h.getPeerCertificate(r) +func Renew(w http.ResponseWriter, r *http.Request) { + cert, err := getPeerCertificate(r) if err != nil { render.Error(w, err) return } - certChain, err := h.Authority.Renew(cert) + a := mustAuthority(r.Context()) + certChain, err := a.Renew(cert) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } -func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { +func getPeerCertificate(r *http.Request) (*x509.Certificate, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { return r.TLS.PeerCertificates[0], nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { - return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) + ctx := r.Context() + return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) } } return nil, errs.BadRequest("missing client certificate") diff --git a/api/revoke.go b/api/revoke.go index c9da2c18..aebbb875 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "golang.org/x/crypto/ocsp" @@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. // // TODO: Add CRL and OCSP support. -func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { +func Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { PassiveOnly: body.Passive, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. if len(body.OTT) > 0 { logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } @@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { opts.MTLS = true } - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } diff --git a/api/revoke_test.go b/api/revoke_test.go index 7635ce68..c3fa6ceb 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) { statusCode: http.StatusOK, tls: cs, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, ri *authority.RevokeOptions) error { @@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusInternalServerError, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusForbidden, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) { for name, _tc := range tests { tc := _tc(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*caHandler) + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input)) if tc.tls != nil { req.TLS = tc.tls } w := httptest.NewRecorder() - h.Revoke(logging.NewResponseLogger(w), req) + Revoke(logging.NewResponseLogger(w), req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/api/sign.go b/api/sign.go index b6bfcc8b..f7c3cc5a 100644 --- a/api/sign.go +++ b/api/sign.go @@ -49,7 +49,7 @@ type SignResponse struct { // Sign is an HTTP handler that reads a certificate request and an // one-time-token (ott) from the body and creates a new certificate with the // information in the certificate request. -func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { +func Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { TemplateData: body.TemplateData, } - signOpts, err := h.Authority.AuthorizeSign(body.OTT) + ctx := r.Context() + a := mustAuthority(ctx) + + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return @@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/ssh.go b/api/ssh.go index 3b0de7c1..4bd20495 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -250,7 +250,7 @@ type SSHBastionResponse struct { // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { +func SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -288,13 +288,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) + cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -302,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var addUserCertificate *SSHCertificate if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { - addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) + addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -315,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if cr := body.IdentityCSR.CertificateRequest; cr != nil { ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -327,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return @@ -344,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { // SSHRoots is an HTTP handler that returns the SSH public keys for user and host // certificates. -func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHRoots(r.Context()) +func SSHRoots(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHRoots(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -369,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { // SSHFederation is an HTTP handler that returns the federated SSH public keys // for user and host certificates. -func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHFederation(r.Context()) +func SSHFederation(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHFederation(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -394,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { // SSHConfig is an HTTP handler that returns rendered templates for ssh clients // and servers. -func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { +func SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -405,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) + ctx := r.Context() + ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -426,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. -func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { +func SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -437,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } - exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) + ctx := r.Context() + exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -448,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { } // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. -func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { +func SSHGetHosts(w http.ResponseWriter, r *http.Request) { var cert *x509.Certificate if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { cert = r.TLS.PeerCertificates[0] } - hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) + ctx := r.Context() + hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -465,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { } // SSHBastion provides returns the bastion configured if any. -func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { +func SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -476,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } - bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) + ctx := r.Context() + bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) if err != nil { render.Error(w, errs.InternalServerErr(err)) return diff --git a/api/sshRekey.go b/api/sshRekey.go index 92278950..6c0a5064 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -39,7 +39,7 @@ type SSHRekeyResponse struct { // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { +func SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -59,7 +59,10 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -70,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) + newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return @@ -80,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return diff --git a/api/sshRenew.go b/api/sshRenew.go index 78d16fa6..4e4d0b04 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -37,7 +37,7 @@ type SSHRenewResponse struct { // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { +func SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -51,7 +51,10 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) - _, err := h.Authority.Authorize(ctx, body.OTT) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + + a := mustAuthority(ctx) + _, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -62,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RenewSSH(ctx, oldCert) + newCert, err := a.RenewSSH(ctx, oldCert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return @@ -72,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return @@ -85,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the -func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { +func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, nil } @@ -105,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte cert.NotAfter = notAfter } - certChain, err := h.Authority.Renew(cert) + certChain, err := mustAuthority(r.Context()).Renew(cert) if err != nil { return nil, err } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index a33082cd..d377def9 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { // Revoke supports handful of different methods that revoke a Certificate. // // NOTE: currently only Passive revocation is supported. -func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { +func SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } diff --git a/api/ssh_test.go b/api/ssh_test.go index 88a301f5..57dd6775 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) { } } -func Test_caHandler_SSHSign(t *testing.T) { +func Test_SSHSign(t *testing.T) { user, err := getSignedUserCertificate() assert.FatalError(t, err) host, err := getSignedHostCertificate() @@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + mockMustAuthority(t, &mockAuthority{ + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { @@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) { sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return tt.tlsSignCerts, tt.tlsSignErr }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHSign(logging.NewResponseLogger(w), req) + SSHSign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } } -func Test_caHandler_SSHRoots(t *testing.T) { +func Test_SSHRoots(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) w := httptest.NewRecorder() - h.SSHRoots(logging.NewResponseLogger(w), req) + SSHRoots(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { } } -func Test_caHandler_SSHFederation(t *testing.T) { +func Test_SSHFederation(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) w := httptest.NewRecorder() - h.SSHFederation(logging.NewResponseLogger(w), req) + SSHFederation(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { } } -func Test_caHandler_SSHConfig(t *testing.T) { +func Test_SSHConfig(t *testing.T) { userOutput := []templates.Output{ {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, @@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHConfig(logging.NewResponseLogger(w), req) + SSHConfig(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { } } -func Test_caHandler_SSHCheckHost(t *testing.T) { +func Test_SSHCheckHost(t *testing.T) { tests := []struct { name string req string @@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHCheckHost(logging.NewResponseLogger(w), req) + SSHCheckHost(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } } -func Test_caHandler_SSHGetHosts(t *testing.T) { +func Test_SSHGetHosts(t *testing.T) { hosts := []authority.Host{ {HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"}, {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, @@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { return tt.hosts, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) w := httptest.NewRecorder() - h.SSHGetHosts(logging.NewResponseLogger(w), req) + SSHGetHosts(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } } -func Test_caHandler_SSHBastion(t *testing.T) { +func Test_SSHBastion(t *testing.T) { bastion := &authority.Bastion{ Hostname: "bastion.local", } @@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHBastion(logging.NewResponseLogger(w), req) + SSHBastion(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 21a7229d..db393e9a 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -1,22 +1,15 @@ package api import ( - "context" "fmt" "net/http" - "github.com/go-chi/chi" - "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/provisioner" -) - -const ( - // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") ) // CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests @@ -40,78 +33,121 @@ type GetExternalAccountKeysResponse struct { // requireEABEnabled is a middleware that ensures ACME EAB is enabled // before serving requests that act on ACME EAB credentials. -func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { +func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - provName := chi.URLParam(r, "provisionerName") - eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) - if err != nil { - render.Error(w, err) + prov := linkedca.MustProvisionerFromContext(ctx) + + acmeProvisioner := prov.GetDetails().GetACME() + if acmeProvisioner == nil { + render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName())) return } - if !eabEnabled { - render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + + if !acmeProvisioner.RequireEab { + render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName())) return } - ctx = context.WithValue(ctx, provisionerContextKey, prov) - next(w, r.WithContext(ctx)) - } -} - -// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME -// provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { - var ( - p provisioner.Interface - err error - ) - if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil { - return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName) - } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) - if err != nil { - return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID()) + next(w, r) } - - details := prov.GetDetails() - if details == nil { - return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID()) - } - - acmeProvisioner := details.GetACME() - if acmeProvisioner == nil { - return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID()) - } - - return acmeProvisioner.GetRequireEab(), prov, nil } -type acmeAdminResponderInterface interface { +// ACMEAdminResponder is responsible for writing ACME admin responses +type ACMEAdminResponder interface { GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) } -// ACMEAdminResponder is responsible for writing ACME admin responses -type ACMEAdminResponder struct{} +// acmeAdminResponder implements ACMEAdminResponder. +type acmeAdminResponder struct{} // NewACMEAdminResponder returns a new ACMEAdminResponder -func NewACMEAdminResponder() *ACMEAdminResponder { - return &ACMEAdminResponder{} +func NewACMEAdminResponder() ACMEAdminResponder { + return &acmeAdminResponder{} } // GetExternalAccountKeys writes the response for the EAB keys GET endpoint -func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint -func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint -func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } + +func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey { + + if k == nil { + return nil + } + + eak := &linkedca.EABKey{ + Id: k.ID, + HmacKey: k.HmacKey, + Provisioner: k.ProvisionerID, + Reference: k.Reference, + Account: k.AccountID, + CreatedAt: timestamppb.New(k.CreatedAt), + BoundAt: timestamppb.New(k.BoundAt), + } + + if k.Policy != nil { + eak.Policy = &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{}, + Deny: &linkedca.X509Names{}, + }, + } + eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames + eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges + eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames + eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges + eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames + } + + return eak +} + +func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey { + if k == nil { + return nil + } + + eak := &acme.ExternalAccountKey{ + ID: k.Id, + ProvisionerID: k.Provisioner, + Reference: k.Reference, + AccountID: k.Account, + HmacKey: k.HmacKey, + CreatedAt: k.CreatedAt.AsTime(), + BoundAt: k.BoundAt.AsTime(), + } + + if policy := k.GetPolicy(); policy != nil { + eak.Policy = &acme.Policy{} + if x509 := policy.GetX509(); x509 != nil { + eak.Policy.X509 = acme.X509Policy{} + if allow := x509.GetAllow(); allow != nil { + eak.Policy.X509.Allowed = acme.PolicyNames{} + eak.Policy.X509.Allowed.DNSNames = allow.Dns + eak.Policy.X509.Allowed.IPRanges = allow.Ips + } + if deny := x509.GetDeny(); deny != nil { + eak.Policy.X509.Denied = acme.PolicyNames{} + eak.Policy.X509.Denied.DNSNames = deny.Dns + eak.Policy.X509.Denied.IPRanges = deny.Ips + } + eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames + } + } + + return eak +} diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 6ffe1418..6d478145 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -4,20 +4,24 @@ import ( "bytes" "context" "encoding/json" - "errors" "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" + "time" "github.com/go-chi/chi" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/admin" ) func readProtoJSON(r io.ReadCloser, m proto.Message) error { @@ -29,109 +33,90 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } +func mockMustAuthority(t *testing.T, a adminAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) adminAuthority { + return a + } +} + func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context - adminDB admin.DB - auth adminAuthority - next nextHTTP + next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ - "fail/h.provisionerHasEABEnabled": func(t *testing.T) test { - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("provisionerName", "provName") - ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return nil, errors.New("force") - }, + "fail/prov.GetDetails": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", } - err := admin.NewErrorISE("error loading provisioner provName: force") - err.Message = "error loading provisioner provName: force" + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") + err.Message = "error getting ACME details for provisioner 'provName'" return test{ ctx: ctx, - auth: auth, err: err, statusCode: 500, } }, - "ok/eab-disabled": func(t *testing.T) test { - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("provisionerName", "provName") - ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, + "fail/prov.GetDetails.GetACME": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{}, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") + err.Message = "error getting ACME details for provisioner 'provName'" + return test{ + ctx: ctx, + err: err, + statusCode: 500, } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: false, - }, - }, + }, + "ok/eab-disabled": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, }, - }, nil + }, }, } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") - err.Message = "ACME EAB not enabled for provisioner provName" + err.Message = "ACME EAB not enabled for provisioner 'provName'" return test{ ctx: ctx, - auth: auth, - adminDB: db, err: err, statusCode: 400, } }, "ok/eab-enabled": func(t *testing.T) test { - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("provisionerName", "provName") - ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: true, - }, - }, - }, - }, nil + }, }, } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, - auth: auth, - adminDB: db, + ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 }, @@ -143,16 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - - req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup - req = req.WithContext(tc.ctx) + req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx) w := httptest.NewRecorder() - h.requireEABEnabled(tc.next)(w, req) + requireEABEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -176,216 +154,6 @@ func TestHandler_requireEABEnabled(t *testing.T) { } } -func TestHandler_provisionerHasEABEnabled(t *testing.T) { - type test struct { - adminDB admin.DB - auth adminAuthority - provisionerName string - want bool - err *admin.Error - } - var tests = map[string]func(t *testing.T) test{ - "fail/auth.LoadProvisionerByName": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return nil, errors.New("force") - }, - } - return test{ - auth: auth, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "fail/db.GetProvisioner": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return nil, errors.New("force") - }, - } - return test{ - auth: auth, - adminDB: db, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "fail/prov.GetDetails": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: nil, - }, nil - }, - } - return test{ - auth: auth, - adminDB: db, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "fail/details.GetACME": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: nil, - }, - }, - }, nil - }, - } - return test{ - auth: auth, - adminDB: db, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "ok/eab-disabled": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "eab-disabled", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "eab-disabled", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: false, - }, - }, - }, - }, nil - }, - } - return test{ - adminDB: db, - auth: auth, - provisionerName: "eab-disabled", - want: false, - } - }, - "ok/eab-enabled": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "eab-enabled", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "eab-enabled", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: true, - }, - }, - }, - }, nil - }, - } - return test{ - adminDB: db, - auth: auth, - provisionerName: "eab-enabled", - want: true, - } - }, - } - for name, prep := range tests { - tc := prep(t) - t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) - if (err != nil) != (tc.err != nil) { - t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) - return - } - if tc.err != nil { - assert.Type(t, &linkedca.Provisioner{}, prov) - assert.Type(t, &admin.Error{}, err) - adminError, _ := err.(*admin.Error) - assert.Equals(t, tc.err.Type, adminError.Type) - assert.Equals(t, tc.err.Status, adminError.Status) - assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode()) - assert.Equals(t, tc.err.Message, adminError.Message) - assert.Equals(t, tc.err.Detail, adminError.Detail) - return - } - if got != tc.want { - t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want) - } - }) - } -} - func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { type fields struct { Reference string @@ -585,3 +353,206 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) { }) } } + +func Test_eakToLinked(t *testing.T) { + tests := []struct { + name string + k *acme.ExternalAccountKey + want *linkedca.EABKey + }{ + { + name: "no-key", + k: nil, + want: nil, + }, + { + name: "no-policy", + k: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: nil, + }, + want: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: nil, + }, + }, + { + name: "with-policy", + k: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/24"}, + }, + Denied: acme.PolicyNames{ + DNSNames: []string{"badhost.local"}, + IPRanges: []string{"10.0.0.30"}, + }, + }, + }, + }, + want: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/24"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"10.0.0.30"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) { + t.Errorf("eakToLinked() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_linkedEAKToCertificates(t *testing.T) { + tests := []struct { + name string + k *linkedca.EABKey + want *acme.ExternalAccountKey + }{ + { + name: "no-key", + k: nil, + want: nil, + }, + { + name: "no-policy", + k: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: nil, + }, + want: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: nil, + }, + }, + { + name: "no-x509-policy", + k: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: &linkedca.Policy{}, + }, + want: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: &acme.Policy{}, + }, + }, + { + name: "with-x509-policy", + k: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/24"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"10.0.0.30"}, + }, + AllowWildcardNames: true, + }, + }, + }, + want: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/24"}, + }, + Denied: acme.PolicyNames{ + DNSNames: []string{"badhost.local"}, + IPRanges: []string{"10.0.0.30"}, + }, + AllowWildcardNames: true, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) { + t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 5e4b9c30..c7adced3 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -29,6 +29,10 @@ type adminAuthority interface { LoadProvisionerByID(id string) (provisioner.Interface, error) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error + GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) + CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + RemoveAuthorityPolicy(ctx context.Context) error } // CreateAdminRequest represents the body for a CreateAdmin request. @@ -81,10 +85,10 @@ type DeleteResponse struct { } // GetAdmin returns the requested admin, or an error. -func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { +func GetAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - adm, ok := h.auth.LoadAdminByID(id) + adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) if !ok { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) @@ -94,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { } // GetAdmins returns a segment of admins associated with the authority. -func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { +func GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -102,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { return } - admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) + admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return @@ -114,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { } // CreateAdmin creates a new admin. -func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { +func CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -126,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { return } - p, err := h.auth.LoadProvisionerByName(body.Provisioner) + auth := mustAuthority(r.Context()) + p, err := auth.LoadProvisionerByName(body.Provisioner) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return @@ -137,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { Type: body.Type, } // Store to authority collection. - if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { + if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } @@ -146,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // DeleteAdmin deletes admin. -func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { +func DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { + if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } @@ -158,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { } // UpdateAdmin updates an existing admin. -func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { +func UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -171,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { } id := chi.URLParam(r, "id") - - adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) + auth := mustAuthority(r.Context()) + adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index 8d223b52..ecb95244 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -14,11 +14,13 @@ import ( "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/types/known/timestamppb" ) type mockAdminAuthority struct { @@ -37,6 +39,11 @@ type mockAdminAuthority struct { MockLoadProvisionerByID func(id string) (provisioner.Interface, error) MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error MockRemoveProvisioner func(ctx context.Context, id string) error + + MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) + MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + MockRemoveAuthorityPolicy func(ctx context.Context) error } func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { @@ -130,6 +137,34 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e return m.MockErr } +func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + if m.MockGetAuthorityPolicy != nil { + return m.MockGetAuthorityPolicy(ctx) + } + return m.MockRet1.(*linkedca.Policy), m.MockErr +} + +func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + if m.MockCreateAuthorityPolicy != nil { + return m.MockCreateAuthorityPolicy(ctx, adm, policy) + } + return m.MockRet1.(*linkedca.Policy), m.MockErr +} + +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + if m.MockUpdateAuthorityPolicy != nil { + return m.MockUpdateAuthorityPolicy(ctx, adm, policy) + } + return m.MockRet1.(*linkedca.Policy), m.MockErr +} + +func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error { + if m.MockRemoveAuthorityPolicy != nil { + return m.MockRemoveAuthorityPolicy(ctx) + } + return m.MockErr +} + func TestCreateAdminRequest_Validate(t *testing.T) { type fields struct { Subject string @@ -317,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmin(w, req) + GetAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -456,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmins(w, req) + GetAdmins(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -640,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateAdmin(w, req) + CreateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -732,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteAdmin(w, req) + DeleteAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -877,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.UpdateAdmin(w, req) + UpdateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 99e74c88..1e5919ce 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -1,56 +1,117 @@ package api import ( + "context" + "net/http" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) // Handler is the Admin API request handler. type Handler struct { - adminDB admin.DB - auth adminAuthority - acmeDB acme.DB - acmeResponder acmeAdminResponderInterface + acmeResponder ACMEAdminResponder + policyResponder PolicyAdminResponder +} + +// Route traffic and implement the Router interface. +// +// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) +func (h *Handler) Route(r api.Router) { + Route(r, h.acmeResponder, h.policyResponder) } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { +// +// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler { return &Handler{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, - acmeResponder: acmeResponder, + acmeResponder: acmeResponder, + policyResponder: policyResponder, } } +var mustAuthority = func(ctx context.Context) adminAuthority { + return authority.MustFromContext(ctx) +} + // Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { - authnz := func(next nextHTTP) nextHTTP { - return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) +func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) { + authnz := func(next http.HandlerFunc) http.HandlerFunc { + return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) + } + + enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { + return checkAction(next, true) + } + + disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { + return checkAction(next, false) } - requireEABEnabled := func(next nextHTTP) nextHTTP { - return h.requireEABEnabled(next) + acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(loadProvisionerByName(requireEABEnabled(next))) + } + + authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(enabledInStandalone(next)) + } + + provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(disabledInStandalone(loadProvisionerByName(next))) + } + + acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next))))) } // Provisioners - r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) - r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) - r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) - r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) - r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) + r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner)) + r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners)) + r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner)) + r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner)) + r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner)) // Admins - r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) - r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) - r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) - r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) - r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) - - // ACME External Account Binding Keys - r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) - r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) + r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin)) + r.MethodFunc("GET", "/admins", authnz(GetAdmins)) + r.MethodFunc("POST", "/admins", authnz(CreateAdmin)) + r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin)) + r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin)) + + // ACME responder + if acmeResponder != nil { + // ACME External Account Binding Keys + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey)) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey)) + } + + // Policy responder + if policyResponder != nil { + // Policy - Authority + r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy)) + r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy)) + r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy)) + r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy)) + + // Policy - Provisioner + r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy)) + r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy)) + r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy)) + r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy)) + + // Policy - ACME Account + r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + } } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index b57dd6eb..780cfb65 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -1,22 +1,26 @@ package api import ( - "context" + "errors" "net/http" + "github.com/go-chi/chi" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/admin/db/nosql" + "github.com/smallstep/certificates/authority/provisioner" ) -type nextHTTP = func(http.ResponseWriter, *http.Request) - // requireAPIEnabled is a middleware that ensures the Administration API // is enabled before servicing requests. -func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { +func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if !h.auth.IsAdminAPIEnabled() { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, - "administration API not enabled")) + if !mustAuthority(r.Context()).IsAdminAPIEnabled() { + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } next(w, r) @@ -24,8 +28,9 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { } // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. -func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { +func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + tok := r.Header.Get("Authorization") if tok == "" { render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, @@ -33,22 +38,111 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - adm, err := h.auth.AuthorizeAdminToken(r, tok) + ctx := r.Context() + adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) if err != nil { render.Error(w, err) return } - ctx := context.WithValue(r.Context(), adminContextKey, adm) + ctx = linkedca.NewContextWithAdmin(ctx, adm) next(w, r.WithContext(ctx)) } } -// ContextKey is the key type for storing and searching for ACME request -// essentials in the context of a request. -type ContextKey string +// loadProvisionerByName is a middleware that searches for a provisioner +// by name and stores it in the context. +func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var ( + p provisioner.Interface + err error + ) -const ( - // adminContextKey account key - adminContextKey = ContextKey("admin") -) + ctx := r.Context() + auth := mustAuthority(ctx) + adminDB := admin.MustFromContext(ctx) + name := chi.URLParam(r, "provisionerName") + + // TODO(hs): distinguish 404 vs. 500 + if p, err = auth.LoadProvisionerByName(name); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + + prov, err := adminDB.GetProvisioner(ctx, p.GetID()) + if err != nil { + render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) + return + } + + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + next(w, r.WithContext(ctx)) + } +} + +// checkAction checks if an action is supported in standalone or not +func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // actions allowed in standalone mode are always supported + if supportedInStandalone { + next(w, r) + return + } + + // when an action is not supported in standalone mode and when + // using a nosql.DB backend, actions are not supported + if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok { + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, + "operation not supported in standalone mode")) + return + } + + // continue to next http handler + next(w, r) + } +} + +// loadExternalAccountKey is a middleware that searches for an ACME +// External Account Key by reference or keyID and stores it in the context. +func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + acmeDB := acme.MustDatabaseFromContext(ctx) + + reference := chi.URLParam(r, "reference") + keyID := chi.URLParam(r, "keyID") + + var ( + eak *acme.ExternalAccountKey + err error + ) + + if keyID != "" { + eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID) + } else { + eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) + } + + if err != nil { + if errors.Is(err, acme.ErrNotFound) { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) + return + } + render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key")) + return + } + + if eak == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) + return + } + + linkedEAK := eakToLinked(eak) + + ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK) + + next(w, r.WithContext(ctx)) + } +} diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 7fb4671a..4684b047 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -4,25 +4,32 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "testing" "time" + "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/admin" - "go.step.sm/linkedca" - "google.golang.org/protobuf/types/known/timestamppb" + "github.com/smallstep/certificates/authority/admin/db/nosql" + "github.com/smallstep/certificates/authority/provisioner" ) func TestHandler_requireAPIEnabled(t *testing.T) { type test struct { ctx context.Context auth adminAuthority - next nextHTTP + next http.HandlerFunc err *admin.Error statusCode int } @@ -64,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.requireAPIEnabled(tc.next)(w, req) + requireAPIEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -102,7 +107,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { ctx context.Context auth adminAuthority req *http.Request - next nextHTTP + next http.HandlerFunc err *admin.Error statusCode int } @@ -152,7 +157,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { req.Header["Authorization"] = []string{"token"} createdAt := time.Now() var deletedAt time.Time - admin := &linkedca.Admin{ + adm := &linkedca.Admin{ Id: "adminID", AuthorityId: "authorityID", Subject: "admin", @@ -164,20 +169,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) - return admin, nil + return adm, nil }, } next := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin - adm, ok := a.(*linkedca.Admin) - if !ok { - t.Errorf("expected *linkedca.Admin; got %T", a) - return - } + adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} - if !cmp.Equal(admin, adm, opts...) { - t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...)) + if !cmp.Equal(adm, adm, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...)) } w.Write(nil) // mock response with status 200 } @@ -194,13 +194,459 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, + mockMustAuthority(t, tc.auth) + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + extractAuthorizeTokenAdmin(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return } + }) + } +} - req := tc.req.WithContext(tc.ctx) +func TestHandler_loadProvisionerByName(t *testing.T) { + type test struct { + adminDB admin.DB + auth adminAuthority + ctx context.Context + next http.HandlerFunc + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName") + err.Message = "error loading provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + adminDB: &admin.MockDB{}, + statusCode: 500, + err: err, + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName") + err.Message = "error retrieving provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + adminDB: db, + statusCode: 500, + err: err, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + }, nil + }, + } + return test{ + ctx: ctx, + auth: auth, + adminDB: db, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + prov := linkedca.MustProvisionerFromContext(r.Context()) + assert.NotNil(t, prov) + assert.Equals(t, "provID", prov.GetId()) + assert.Equals(t, "provName", prov.GetName()) + w.Write(nil) // mock response with status 200 + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + loadProvisionerByName(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} + +func TestHandler_checkAction(t *testing.T) { + type test struct { + adminDB admin.DB + next http.HandlerFunc + supportedInStandalone bool + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "standalone-nosql-supported": func(t *testing.T) test { + return test{ + supportedInStandalone: true, + adminDB: &nosql.DB{}, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + } + }, + "standalone-nosql-not-supported": func(t *testing.T) test { + err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode") + err.Message = "operation not supported in standalone mode" + return test{ + supportedInStandalone: false, + adminDB: &nosql.DB{}, + statusCode: 501, + err: err, + } + }, + "standalone-no-nosql-not-supported": func(t *testing.T) test { + err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported") + err.Message = "operation not supported" + return test{ + supportedInStandalone: false, + adminDB: &admin.MockDB{}, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + err: err, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + ctx := admin.NewContext(context.Background(), tc.adminDB) + req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx) + w := httptest.NewRecorder() + checkAction(tc.next, tc.supportedInStandalone)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} + +func TestHandler_loadExternalAccountKey(t *testing.T) { + type test struct { + ctx context.Context + acmeDB acme.DB + next http.HandlerFunc + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/keyID-not-found-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("keyID", "key") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "key", keyID) + return nil, acme.ErrNotFound + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/keyID-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("keyID", "key") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key") + err.Message = "error retrieving ACME External Account Key: force" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "key", keyID) + return nil, errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/reference-not-found-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return nil, acme.ErrNotFound + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/reference-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key") + err.Message = "error retrieving ACME External Account Key: force" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return nil, errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/no-key": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return nil, nil + }, + }, + err: err, + statusCode: 404, + } + }, + "ok/keyID": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("keyID", "eakID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + createdAt := time.Now().Add(-1 * time.Hour) + var boundAt time.Time + eak := &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + CreatedAt: createdAt, + BoundAt: boundAt, + } + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "eakID", keyID) + return eak, nil + }, + }, + next: func(w http.ResponseWriter, r *http.Request) { + contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context()) + assert.NotNil(t, eak) + exp := &linkedca.EABKey{ + Id: "eakID", + Provisioner: "provID", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAt), + } + assert.Equals(t, exp, contextEAK) + w.Write(nil) // mock response with status 200 + }, + err: nil, + statusCode: 200, + } + }, + "ok/reference": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + createdAt := time.Now().Add(-1 * time.Hour) + var boundAt time.Time + eak := &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "ref", + CreatedAt: createdAt, + BoundAt: boundAt, + } + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return eak, nil + }, + }, + next: func(w http.ResponseWriter, r *http.Request) { + contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context()) + assert.NotNil(t, eak) + exp := &linkedca.EABKey{ + Id: "eakID", + Provisioner: "provID", + Reference: "ref", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAt), + } + assert.Equals(t, exp, contextEAK) + w.Write(nil) // mock response with status 200 + }, + err: nil, + statusCode: 200, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB) + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.extractAuthorizeTokenAdmin(tc.next)(w, req) + loadExternalAccountKey(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go new file mode 100644 index 00000000..a478c83c --- /dev/null +++ b/authority/admin/api/policy.go @@ -0,0 +1,496 @@ +package api + +import ( + "context" + "errors" + "net/http" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/policy" +) + +// PolicyAdminResponder is the interface responsible for writing ACME admin +// responses. +type PolicyAdminResponder interface { + GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) + CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) + UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) + DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) + GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) + CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) + UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) + DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) + GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) +} + +// policyAdminResponder implements PolicyAdminResponder. +type policyAdminResponder struct{} + +// NewACMEAdminResponder returns a new PolicyAdminResponder. +func NewPolicyAdminResponder() PolicyAdminResponder { + return &policyAdminResponder{} +} + +// GetAuthorityPolicy handles the GET /admin/authority/policy request +func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(r.Context()) + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + return + } + + if authorityPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + return + } + + render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK) +} + +// CreateAuthorityPolicy handles the POST /admin/authority/policy request +func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + return + } + + if authorityPolicy != nil { + adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy") + render.Error(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) + return + } + + adm := linkedca.MustAdminFromContext(ctx) + + var createdPolicy *linkedca.Policy + if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error storing authority policy")) + return + } + + render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated) +} + +// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request +func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + return + } + + if authorityPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) + return + } + + adm := linkedca.MustAdminFromContext(ctx) + + var updatedPolicy *linkedca.Policy + if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error updating authority policy")) + return + } + + render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK) +} + +// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request +func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + return + } + + if authorityPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + return + } + + if err := auth.RemoveAuthorityPolicy(ctx); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) + return + } + + render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request +func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + provisionerPolicy := prov.GetPolicy() + if provisionerPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + return + } + + render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK) +} + +// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request +func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + provisionerPolicy := prov.GetPolicy() + if provisionerPolicy != nil { + adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) + render.Error(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) + return + } + + prov.Policy = newPolicy + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) +} + +// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request +func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + provisionerPolicy := prov.GetPolicy() + if provisionerPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) + return + } + + prov.Policy = newPolicy + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusOK) +} + +// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request +func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + if prov.Policy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + return + } + + // remove the policy + prov.Policy = nil + + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) + return + } + + render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + eakPolicy := eak.GetPolicy() + if eakPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + return + } + + render.ProtoJSONStatus(w, eakPolicy, http.StatusOK) +} + +func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + eakPolicy := eak.GetPolicy() + if eakPolicy != nil { + adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) + render.Error(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) + return + } + + eak.Policy = newPolicy + + acmeEAK := linkedEAKToCertificates(eak) + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) +} + +func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + eakPolicy := eak.GetPolicy() + if eakPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) + return + } + + eak.Policy = newPolicy + acmeEAK := linkedEAKToCertificates(eak) + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusOK) +} + +func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(ctx) + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + eakPolicy := eak.GetPolicy() + if eakPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + return + } + + // remove the policy + eak.Policy = nil + + acmeEAK := linkedEAKToCertificates(eak) + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) + return + } + + render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +// blockLinkedCA blocks all API operations on linked deployments +func blockLinkedCA(ctx context.Context) error { + // temporary blocking linked deployments + adminDB := admin.MustFromContext(ctx) + if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() { + return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + } + return nil +} + +// isBadRequest checks if an error should result in a bad request error +// returned to the client. +func isBadRequest(err error) bool { + var pe *authority.PolicyError + isPolicyError := errors.As(err, &pe) + return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure) +} + +func validatePolicy(p *linkedca.Policy) error { + + // convert the policy; return early if nil + options := policy.LinkedToCertificates(p) + if options == nil { + return nil + } + + var err error + + // Initialize a temporary x509 allow/deny policy engine + if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil { + return err + } + + // Initialize a temporary SSH allow/deny policy engine for host certificates + if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { + return err + } + + // Initialize a temporary SSH allow/deny policy engine for user certificates + if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { + return err + } + + return nil +} diff --git a/authority/admin/api/policy_test.go b/authority/admin/api/policy_test.go new file mode 100644 index 00000000..1ec88fb6 --- /dev/null +++ b/authority/admin/api/policy_test.go @@ -0,0 +1,2775 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/encoding/protojson" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" +) + +type fakeLinkedCA struct { + admin.MockDB +} + +func (f *fakeLinkedCA) IsLinkedCA() bool { + return true +} + +// testAdminError is an error type that models the expected +// error body returned. +type testAdminError struct { + Type string `json:"type"` + Message string `json:"message"` + Detail string `json:"detail"` +} + +type testX509Policy struct { + Allow *testX509Names `json:"allow,omitempty"` + Deny *testX509Names `json:"deny,omitempty"` + AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"` +} + +type testX509Names struct { + CommonNames []string `json:"commonNames,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ips,omitempty"` + EmailAddresses []string `json:"emails,omitempty"` + URIDomains []string `json:"uris,omitempty"` +} + +type testSSHPolicy struct { + User *testSSHUserPolicy `json:"user,omitempty"` + Host *testSSHHostPolicy `json:"host,omitempty"` +} + +type testSSHHostPolicy struct { + Allow *testSSHHostNames `json:"allow,omitempty"` + Deny *testSSHHostNames `json:"deny,omitempty"` +} + +type testSSHHostNames struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ips,omitempty"` + Principals []string `json:"principals,omitempty"` +} + +type testSSHUserPolicy struct { + Allow *testSSHUserNames `json:"allow,omitempty"` + Deny *testSSHUserNames `json:"deny,omitempty"` +} + +type testSSHUserNames struct { + EmailAddresses []string `json:"emails,omitempty"` + Principals []string `json:"principals,omitempty"` +} + +// testPolicyResponse models the Policy API JSON response +type testPolicyResponse struct { + X509 *testX509Policy `json:"x509,omitempty"` + SSH *testSSHPolicy `json:"ssh,omitempty"` +} + +func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/auth.GetAuthorityPolicy-not-found": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") + err.Message = "authority policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + err: err, + statusCode: 404, + } + }, + "ok": func(t *testing.T) test { + ctx := context.Background() + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + } + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + }, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.GetAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + }) + } +} + +func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/existing-policy": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorConflictType, "authority already has a policy") + err.Message = "authority already has a policy" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{}, nil + }, + }, + err: err, + statusCode: 409, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/CreateAuthorityPolicy-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error storing authority policy") + adminErr.Message = "error storing authority policy: admin lock out" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/CreateAuthorityPolicy-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error storing authority policy: force") + adminErr.Message = "error storing authority policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return policy, nil + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 201, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.CreateAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") + err.Message = "authority policy does not exist" + err.Status = http.StatusNotFound + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/UpdateAuthorityPolicy-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating authority policy: force") + adminErr.Message = "error updating authority policy: admin lock out" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/UpdateAuthorityPolicy-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating authority policy: force") + adminErr.Message = "error updating authority policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return policy, nil + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.UpdateAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + statusCode int + } + + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") + err.Message = "authority policy does not exist" + err.Status = http.StatusNotFound + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/auth.RemoveAuthorityPolicy-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + err := admin.NewErrorISE("error deleting authority policy: force") + err.Message = "error deleting authority policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockRemoveAuthorityPolicy: func(ctx context.Context) error { + return errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockRemoveAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.DeleteAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + res.Body.Close() + response := DeleteResponse{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + ctx context.Context + acmeDB acme.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/prov-no-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{} + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") + err.Message = "provisioner policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 404, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + } + prov := &linkedca.Provisioner{ + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.GetProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/existing-policy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorConflictType, "provisioner provName already has a policy") + err.Message = "provisioner provName already has a policy" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 409, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error creating provisioner policy") + adminErr.Message = "error creating provisioner policy: admin lock out" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating provisioner policy: force") + adminErr.Message = "error creating provisioner policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 201, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.CreateProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + body []byte + adminDB admin.DB + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") + err.Message = "provisioner policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 404, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating provisioner policy") + adminErr.Message = "error updating provisioner policy: admin lock out" + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating provisioner policy: force") + adminErr.Message = "error updating provisioner policy: force" + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.UpdateProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + statusCode int + } + + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") + err.Message = "provisioner policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 404, + } + }, + "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: &linkedca.Policy{}, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewErrorISE("error deleting provisioner policy: force") + err.Message = "error deleting provisioner policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: &linkedca.Policy{}, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return nil + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.DeleteProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + res.Body.Close() + response := DeleteResponse{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { + type test struct { + ctx context.Context + acmeDB acme.DB + adminDB admin.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") + err.Message = "ACME EAK policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 404, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.GetACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { + type test struct { + acmeDB acme.DB + adminDB admin.DB + body []byte + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/existing-policy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorConflictType, "ACME EAK eakID already has a policy") + err.Message = "ACME EAK eakID already has a policy" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 409, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating ACME EAK policy") + adminErr.Message = "error creating ACME EAK policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return errors.New("force") + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 201, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.CreateACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { + type test struct { + acmeDB acme.DB + adminDB admin.DB + body []byte + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") + err.Message = "ACME EAK policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 404, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating ACME EAK policy: force") + adminErr.Message = "error updating ACME EAK policy: force" + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return errors.New("force") + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.UpdateACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { + type test struct { + body []byte + adminDB admin.DB + ctx context.Context + acmeDB acme.DB + err *admin.Error + statusCode int + } + + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") + err.Message = "ACME EAK policy does not exist" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + err: err, + statusCode: 404, + } + }, + "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewErrorISE("error deleting ACME EAK policy: force") + err.Message = "error deleting ACME EAK policy: force" + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + return test{ + ctx: ctx, + adminDB: &admin.MockDB{}, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return nil + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + par.DeleteACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + res.Body.Close() + response := DeleteResponse{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func Test_isBadRequest(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil", + err: nil, + want: false, + }, + { + name: "no-policy-error", + err: errors.New("some error"), + want: false, + }, + { + name: "no-bad-request", + err: &authority.PolicyError{ + Typ: authority.InternalFailure, + Err: errors.New("error"), + }, + want: false, + }, + { + name: "bad-request", + err: &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isBadRequest(tt.err); got != tt.want { + t.Errorf("isBadRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_validatePolicy(t *testing.T) { + type args struct { + p *linkedca.Policy + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "nil", + args: args{ + p: nil, + }, + wantErr: false, + }, + { + name: "x509", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"**.local"}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ssh user", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@@example.com"}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ssh host", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"**.local"}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + }, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validatePolicy(tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("validatePolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 1cad62dd..149f2c6a 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -23,29 +23,31 @@ type GetProvisionersResponse struct { } // GetProvisioner returns the requested provisioner, or an error. -func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func GetProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + ctx := r.Context() + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, err) return @@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { } // GetProvisioners returns the given segment of provisioners associated with the authority. -func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { +func GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { return } - p, next, err := h.auth.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { } // CreateProvisioner creates a new prov. -func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { +func CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { render.Error(w, err) @@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { + if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } @@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { } // DeleteProvisioner deletes a provisioner. -func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func DeleteProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(r.Context()) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { + if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } @@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { } // UpdateProvisioner updates an existing prov. -func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { +func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { render.Error(w, err) return } + ctx := r.Context() name := chi.URLParam(r, "name") - _old, err := h.auth.LoadProvisionerByName(name) + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + + p, err := auth.LoadProvisionerByName(name) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } - old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) + old, err := db.GetProvisioner(r.Context(), p.GetID()) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) return } @@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { + if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { render.Error(w, err) return } diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 6d5024f2..d050bca6 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -8,18 +8,21 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/types/known/timestamppb" ) func TestHandler_GetProvisioner(t *testing.T) { @@ -47,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -71,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -153,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } - req := tc.req.WithContext(tc.ctx) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + req := tc.req.WithContext(ctx) w := httptest.NewRecorder() - h.GetProvisioner(w, req) + GetProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -277,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetProvisioners(w, req) + GetProvisioners(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -335,12 +336,12 @@ func TestHandler_CreateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, - statusCode: 500, - err: &admin.Error{ // TODO(hs): this probably needs a better error - Type: "", - Status: 500, - Detail: "", - Message: "", + statusCode: 400, + err: &admin.Error{ + Type: "badRequest", + Status: 400, + Detail: "bad request", + Message: "proto: syntax error (line 1:2): invalid value !", }, } }, @@ -402,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateProvisioner(w, req) + CreateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -423,9 +422,15 @@ func TestHandler_CreateProvisioner(t *testing.T) { assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) - assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(adminErr.Message, "syntax error")) + } else { + assert.Equals(t, tc.err.Message, adminErr.Message) + } + return } @@ -562,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteProvisioner(w, req) + DeleteProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -616,12 +619,13 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, - statusCode: 500, - err: &admin.Error{ // TODO(hs): this probably needs a better error - Type: "", - Status: 500, - Detail: "", - Message: "", + adminDB: &admin.MockDB{}, + statusCode: 400, + err: &admin.Error{ + Type: "badRequest", + Status: 400, + Detail: "bad request", + Message: "proto: syntax error (line 1:2): invalid value !", }, } }, @@ -645,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: ctx, body: body, + adminDB: &admin.MockDB{}, auth: auth, statusCode: 500, err: &admin.Error{ @@ -1052,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.UpdateProvisioner(w, req) + UpdateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -1074,9 +1077,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) { assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) - assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(adminErr.Message, "syntax error")) + } else { + assert.Equals(t, tc.err.Message, adminErr.Message) + } + return } diff --git a/authority/admin/db.go b/authority/admin/db.go index bf34a3c2..b331cc0a 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -69,6 +69,34 @@ type DB interface { GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error DeleteAdmin(ctx context.Context, id string) error + + CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + DeleteAuthorityPolicy(ctx context.Context) error +} + +type dbKey struct{} + +// NewContext adds the given admin database to the context. +func NewContext(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current admin database from the given context. +func FromContext(ctx context.Context) (db DB, ok bool) { + db, ok = ctx.Value(dbKey{}).(DB) + return +} + +// MustFromContext returns the current admin database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) DB { + if db, ok := FromContext(ctx); !ok { + panic("admin database is not in the context") + } else { + return db + } } // MockDB is an implementation of the DB interface that should only be used as @@ -86,6 +114,11 @@ type MockDB struct { MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockDeleteAdmin func(ctx context.Context, id string) error + MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) + MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockDeleteAuthorityPolicy func(ctx context.Context) error + MockError error MockRet1 interface{} } @@ -179,3 +212,35 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error { } return m.MockError } + +// CreateAuthorityPolicy mock +func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + if m.MockCreateAuthorityPolicy != nil { + return m.MockCreateAuthorityPolicy(ctx, policy) + } + return m.MockError +} + +// GetAuthorityPolicy mock +func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + if m.MockGetAuthorityPolicy != nil { + return m.MockGetAuthorityPolicy(ctx) + } + return m.MockRet1.(*linkedca.Policy), m.MockError +} + +// UpdateAuthorityPolicy mock +func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + if m.MockUpdateAuthorityPolicy != nil { + return m.MockUpdateAuthorityPolicy(ctx, policy) + } + return m.MockError +} + +// DeleteAuthorityPolicy mock +func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error { + if m.MockDeleteAuthorityPolicy != nil { + return m.MockDeleteAuthorityPolicy(ctx) + } + return m.MockError +} diff --git a/authority/admin/db/nosql/nosql.go b/authority/admin/db/nosql/nosql.go index 22b049f5..32e05d92 100644 --- a/authority/admin/db/nosql/nosql.go +++ b/authority/admin/db/nosql/nosql.go @@ -11,8 +11,9 @@ import ( ) var ( - adminsTable = []byte("admins") - provisionersTable = []byte("provisioners") + adminsTable = []byte("admins") + provisionersTable = []byte("provisioners") + authorityPoliciesTable = []byte("authority_policies") ) // DB is a struct that implements the AdminDB interface. @@ -23,7 +24,7 @@ type DB struct { // New configures and returns a new Authority DB backend implemented using a nosql DB. func New(db nosqlDB.DB, authorityID string) (*DB, error) { - tables := [][]byte{adminsTable, provisionersTable} + tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable} for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", diff --git a/authority/admin/db/nosql/policy.go b/authority/admin/db/nosql/policy.go new file mode 100644 index 00000000..d4f2e9f9 --- /dev/null +++ b/authority/admin/db/nosql/policy.go @@ -0,0 +1,339 @@ +package nosql + +import ( + "context" + "encoding/json" + "fmt" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/nosql" +) + +type dbX509Policy struct { + Allow *dbX509Names `json:"allow,omitempty"` + Deny *dbX509Names `json:"deny,omitempty"` + AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"` +} + +type dbX509Names struct { + CommonNames []string `json:"cn,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +type dbSSHPolicy struct { + // User contains SSH user certificate options. + User *dbSSHUserPolicy `json:"user,omitempty"` + // Host contains SSH host certificate options. + Host *dbSSHHostPolicy `json:"host,omitempty"` +} + +type dbSSHHostPolicy struct { + Allow *dbSSHHostNames `json:"allow,omitempty"` + Deny *dbSSHHostNames `json:"deny,omitempty"` +} + +type dbSSHHostNames struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +type dbSSHUserPolicy struct { + Allow *dbSSHUserNames `json:"allow,omitempty"` + Deny *dbSSHUserNames `json:"deny,omitempty"` +} + +type dbSSHUserNames struct { + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +type dbPolicy struct { + X509 *dbX509Policy `json:"x509,omitempty"` + SSH *dbSSHPolicy `json:"ssh,omitempty"` +} + +type dbAuthorityPolicy struct { + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + Policy *dbPolicy `json:"policy,omitempty"` +} + +func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy { + if dbap == nil { + return nil + } + return dbToLinked(dbap.Policy) +} + +func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) { + data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID)) + if nosql.IsErrNotFound(err) { + return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found") + } else if err != nil { + return nil, fmt.Errorf("error loading authority policy: %w", err) + } + return data, nil +} + +func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) { + if len(data) == 0 { + return nil, nil + } + var dba = new(dbAuthorityPolicy) + if err := json.Unmarshal(data, dba); err != nil { + return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err) + } + return dba, nil +} + +func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) { + data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID) + if err != nil { + return nil, err + } + dbap, err := db.unmarshalDBAuthorityPolicy(data) + if err != nil { + return nil, err + } + if dbap == nil { + return nil, nil + } + if dbap.AuthorityID != authorityID { + return nil, admin.NewError(admin.ErrorAuthorityMismatchType, + "authority policy is not owned by authority %s", authorityID) + } + return dbap, nil +} + +func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + + dbap := &dbAuthorityPolicy{ + ID: db.authorityID, + AuthorityID: db.authorityID, + Policy: linkedToDB(policy), + } + + if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error creating authority policy") + } + + return nil +} + +func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return nil, err + } + + return dbap.convert(), nil +} + +func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + + dbap := &dbAuthorityPolicy{ + ID: db.authorityID, + AuthorityID: db.authorityID, + Policy: linkedToDB(policy), + } + + if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error updating authority policy") + } + + return nil +} + +func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error { + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + + if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error deleting authority policy") + } + + return nil +} + +func dbToLinked(p *dbPolicy) *linkedca.Policy { + if p == nil { + return nil + } + r := &linkedca.Policy{} + if x509 := p.X509; x509 != nil { + r.X509 = &linkedca.X509Policy{} + if allow := x509.Allow; allow != nil { + r.X509.Allow = &linkedca.X509Names{} + r.X509.Allow.Dns = allow.DNSDomains + r.X509.Allow.Emails = allow.EmailAddresses + r.X509.Allow.Ips = allow.IPRanges + r.X509.Allow.Uris = allow.URIDomains + r.X509.Allow.CommonNames = allow.CommonNames + } + if deny := x509.Deny; deny != nil { + r.X509.Deny = &linkedca.X509Names{} + r.X509.Deny.Dns = deny.DNSDomains + r.X509.Deny.Emails = deny.EmailAddresses + r.X509.Deny.Ips = deny.IPRanges + r.X509.Deny.Uris = deny.URIDomains + r.X509.Deny.CommonNames = deny.CommonNames + } + r.X509.AllowWildcardNames = x509.AllowWildcardNames + } + if ssh := p.SSH; ssh != nil { + r.Ssh = &linkedca.SSHPolicy{} + if host := ssh.Host; host != nil { + r.Ssh.Host = &linkedca.SSHHostPolicy{} + if allow := host.Allow; allow != nil { + r.Ssh.Host.Allow = &linkedca.SSHHostNames{} + r.Ssh.Host.Allow.Dns = allow.DNSDomains + r.Ssh.Host.Allow.Ips = allow.IPRanges + r.Ssh.Host.Allow.Principals = allow.Principals + } + if deny := host.Deny; deny != nil { + r.Ssh.Host.Deny = &linkedca.SSHHostNames{} + r.Ssh.Host.Deny.Dns = deny.DNSDomains + r.Ssh.Host.Deny.Ips = deny.IPRanges + r.Ssh.Host.Deny.Principals = deny.Principals + } + } + if user := ssh.User; user != nil { + r.Ssh.User = &linkedca.SSHUserPolicy{} + if allow := user.Allow; allow != nil { + r.Ssh.User.Allow = &linkedca.SSHUserNames{} + r.Ssh.User.Allow.Emails = allow.EmailAddresses + r.Ssh.User.Allow.Principals = allow.Principals + } + if deny := user.Deny; deny != nil { + r.Ssh.User.Deny = &linkedca.SSHUserNames{} + r.Ssh.User.Deny.Emails = deny.EmailAddresses + r.Ssh.User.Deny.Principals = deny.Principals + } + } + } + + return r +} + +func linkedToDB(p *linkedca.Policy) *dbPolicy { + + if p == nil { + return nil + } + + // return early if x509 nor SSH is set + if p.GetX509() == nil && p.GetSsh() == nil { + return nil + } + + r := &dbPolicy{} + // fill x509 policy configuration + if x509 := p.GetX509(); x509 != nil { + r.X509 = &dbX509Policy{} + if allow := x509.GetAllow(); allow != nil { + r.X509.Allow = &dbX509Names{} + if allow.Dns != nil { + r.X509.Allow.DNSDomains = allow.Dns + } + if allow.Ips != nil { + r.X509.Allow.IPRanges = allow.Ips + } + if allow.Emails != nil { + r.X509.Allow.EmailAddresses = allow.Emails + } + if allow.Uris != nil { + r.X509.Allow.URIDomains = allow.Uris + } + if allow.CommonNames != nil { + r.X509.Allow.CommonNames = allow.CommonNames + } + } + if deny := x509.GetDeny(); deny != nil { + r.X509.Deny = &dbX509Names{} + if deny.Dns != nil { + r.X509.Deny.DNSDomains = deny.Dns + } + if deny.Ips != nil { + r.X509.Deny.IPRanges = deny.Ips + } + if deny.Emails != nil { + r.X509.Deny.EmailAddresses = deny.Emails + } + if deny.Uris != nil { + r.X509.Deny.URIDomains = deny.Uris + } + if deny.CommonNames != nil { + r.X509.Deny.CommonNames = deny.CommonNames + } + } + + r.X509.AllowWildcardNames = x509.GetAllowWildcardNames() + } + + // fill ssh policy configuration + if ssh := p.GetSsh(); ssh != nil { + r.SSH = &dbSSHPolicy{} + if host := ssh.GetHost(); host != nil { + r.SSH.Host = &dbSSHHostPolicy{} + if allow := host.GetAllow(); allow != nil { + r.SSH.Host.Allow = &dbSSHHostNames{} + if allow.Dns != nil { + r.SSH.Host.Allow.DNSDomains = allow.Dns + } + if allow.Ips != nil { + r.SSH.Host.Allow.IPRanges = allow.Ips + } + if allow.Principals != nil { + r.SSH.Host.Allow.Principals = allow.Principals + } + } + if deny := host.GetDeny(); deny != nil { + r.SSH.Host.Deny = &dbSSHHostNames{} + if deny.Dns != nil { + r.SSH.Host.Deny.DNSDomains = deny.Dns + } + if deny.Ips != nil { + r.SSH.Host.Deny.IPRanges = deny.Ips + } + if deny.Principals != nil { + r.SSH.Host.Deny.Principals = deny.Principals + } + } + } + if user := ssh.GetUser(); user != nil { + r.SSH.User = &dbSSHUserPolicy{} + if allow := user.GetAllow(); allow != nil { + r.SSH.User.Allow = &dbSSHUserNames{} + if allow.Emails != nil { + r.SSH.User.Allow.EmailAddresses = allow.Emails + } + if allow.Principals != nil { + r.SSH.User.Allow.Principals = allow.Principals + } + } + if deny := user.GetDeny(); deny != nil { + r.SSH.User.Deny = &dbSSHUserNames{} + if deny.Emails != nil { + r.SSH.User.Deny.EmailAddresses = deny.Emails + } + if deny.Principals != nil { + r.SSH.User.Deny.Principals = deny.Principals + } + } + } + } + + return r +} diff --git a/authority/admin/db/nosql/policy_test.go b/authority/admin/db/nosql/policy_test.go new file mode 100644 index 00000000..3ffded6b --- /dev/null +++ b/authority/admin/db/nosql/policy_test.go @@ -0,0 +1,1206 @@ +package nosql + +import ( + "context" + "encoding/json" + "errors" + "reflect" + "testing" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + "go.step.sm/linkedca" +) + +func TestDB_getDBAuthorityPolicyBytes(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte("foo"), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if b, err := d.getDBAuthorityPolicyBytes(tc.ctx, tc.authorityID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, string(b), "foo") + } + }) + } +} + +func TestDB_getDBAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + dbap *dbAuthorityPolicy + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling policy bytes into dbAuthorityPolicy"), + } + }, + "fail/authorityID-error": func(t *testing.T) test { + dbp := &dbAuthorityPolicy{ + ID: "ID", + AuthorityID: "diffAuthID", + Policy: linkedToDB(&linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }), + } + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return b, nil + }, + }, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "authority policy is not owned by authority authID"), + } + }, + "ok/empty-bytes": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte{}, nil + }, + }, + } + }, + "ok": func(t *testing.T) test { + dbap := &dbAuthorityPolicy{ + ID: "ID", + AuthorityID: authID, + Policy: linkedToDB(&linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }), + } + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return b, nil + }, + }, + dbap: dbap, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + dbp, err := d.getDBAuthorityPolicy(tc.ctx, tc.authorityID) + switch { + case err != nil: + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) && tc.dbap == nil: + assert.Nil(t, dbp) + case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr): + assert.Equals(t, dbp.ID, "ID") + assert.Equals(t, dbp.AuthorityID, tc.dbap.AuthorityID) + assert.Equals(t, dbp.Policy, tc.dbap.Policy) + } + }) + } +} + +func TestDB_CreateAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error creating authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, old, nil) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.CreateAuthorityPolicy(tc.ctx, tc.policy); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_GetAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(policy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + got, err := d.GetAuthorityPolicy(tc.ctx) + if err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + + assert.NotNil(t, got) + assert.Equals(t, tc.policy, got) + }) + } +} + +func TestDB_UpdateAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error updating authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.UpdateAuthorityPolicy(tc.ctx, tc.policy); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + }) + } +} + +func TestDB_DeleteAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + assert.Equals(t, nil, nu) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error deleting authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + assert.Equals(t, nil, nu) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.DeleteAuthorityPolicy(tc.ctx); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + }) + } +} + +func Test_linkedToDB(t *testing.T) { + type args struct { + p *linkedca.Policy + } + tests := []struct { + name string + args args + want *dbPolicy + }{ + { + name: "nil policy", + args: args{ + p: nil, + }, + want: nil, + }, + { + name: "no x509 nor ssh", + args: args{ + p: &linkedca.Policy{}, + }, + want: nil, + }, + { + name: "x509", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + want: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + { + name: "ssh user", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + want: &dbPolicy{ + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + { + name: "full ssh policy", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &dbPolicy{ + SSH: &dbSSHPolicy{ + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + { + name: "full policy", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := linkedToDB(tt.args.p); !reflect.DeepEqual(got, tt.want) { + t.Errorf("linkedToDB() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_dbToLinked(t *testing.T) { + type args struct { + p *dbPolicy + } + tests := []struct { + name string + args args + want *linkedca.Policy + }{ + { + name: "nil policy", + args: args{ + p: nil, + }, + want: nil, + }, + { + name: "x509", + args: args{ + p: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + { + name: "ssh user", + args: args{ + p: &dbPolicy{ + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + want: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + { + name: "ssh host", + args: args{ + p: &dbPolicy{ + SSH: &dbSSHPolicy{ + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + { + name: "full policy", + args: args{ + p: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dbToLinked(tt.args.p); !reflect.DeepEqual(got, tt.want) { + t.Errorf("dbToLinked() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/admin/errors.go b/authority/admin/errors.go index baa32dd9..2cf0c0e5 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -24,10 +24,12 @@ const ( ErrorBadRequestType // ErrorNotImplementedType not implemented. ErrorNotImplementedType - // ErrorUnauthorizedType internal server error. + // ErrorUnauthorizedType unauthorized. ErrorUnauthorizedType // ErrorServerInternalType internal server error. ErrorServerInternalType + // ErrorConflictType conflict. + ErrorConflictType ) // String returns the string representation of the admin problem type, @@ -48,6 +50,8 @@ func (ap ProblemType) String() string { return "unauthorized" case ErrorServerInternalType: return "internalServerError" + case ErrorConflictType: + return "conflict" default: return fmt.Sprintf("unsupported error type '%d'", int(ap)) } @@ -64,7 +68,7 @@ var ( errorServerInternalMetadata = errorMetadata{ typ: ErrorServerInternalType.String(), details: "the server experienced an internal error", - status: 500, + status: http.StatusInternalServerError, } errorMap = map[ProblemType]errorMetadata{ ErrorNotFoundType: { @@ -98,6 +102,11 @@ var ( status: http.StatusUnauthorized, }, ErrorServerInternalType: errorServerInternalMetadata, + ErrorConflictType: { + typ: ErrorConflictType.String(), + details: "conflict", + status: http.StatusConflict, + }, } ) diff --git a/authority/administrator/collection.go b/authority/administrator/collection.go index 88d7bb2c..f40e7417 100644 --- a/authority/administrator/collection.go +++ b/authority/administrator/collection.go @@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv { return subProv{subject, prov} } -// LoadBySubProv a admin by the subject and provisioner name. +// LoadBySubProv loads an admin by subject and provisioner name. func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) { return loadAdmin(c.bySubProv, newSubProv(sub, provName)) } -// LoadByProvisioner a admin by the subject and provisioner name. +// LoadByProvisioner loads admins by provisioner name. func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) { val, ok := c.byProv.Load(provName) if !ok { @@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool } // Store adds an admin to the collection and enforces the uniqueness of -// admin IDs and amdin subject <-> provisioner name combos. +// admin IDs and admin subject <-> provisioner name combos. func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { // Input validation. if adm.ProvisionerId != prov.GetID() { diff --git a/authority/admins.go b/authority/admins.go index b975297a..c8e1ac66 100644 --- a/authority/admins.go +++ b/authority/admins.go @@ -49,7 +49,7 @@ func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov pr return admin.WrapErrorISE(err, "error creating admin") } if err := a.admins.Store(adm, prov); err != nil { - if err := a.reloadAdminResources(ctx); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store") } return admin.WrapErrorISE(err, "error storing admin in authority cache") @@ -66,7 +66,7 @@ func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Adm return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id) } if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil { - if err := a.reloadAdminResources(ctx); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update") } return nil, admin.WrapErrorISE(err, "error updating admin %s", id) @@ -88,7 +88,7 @@ func (a *Authority) removeAdmin(ctx context.Context, id string) error { return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id) } if err := a.adminDB.DeleteAdmin(ctx, id); err != nil { - if err := a.reloadAdminResources(ctx); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove") } return admin.WrapErrorISE(err, "error deleting admin %s", id) diff --git a/authority/authority.go b/authority/authority.go index 3adfc14b..dc416df7 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -13,10 +13,16 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/pemutil" + "go.step.sm/linkedca" + "github.com/smallstep/certificates/authority/admin" adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" "github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" @@ -27,9 +33,6 @@ import ( "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" "github.com/smallstep/nosql" - "go.step.sm/crypto/pemutil" - "go.step.sm/linkedca" - "golang.org/x/crypto/ssh" ) // Authority implements the Certificate Authority internal interface. @@ -81,9 +84,16 @@ type Authority struct { authorizeRenewFunc provisioner.AuthorizeRenewFunc authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc + // Policy engines + policyEngine *policy.Engine + adminMutex sync.RWMutex + + // Do Not initialize the authority + skipInit bool } +// Info contains information about the authority. type Info struct { StartTime time.Time RootX509Certs []*x509.Certificate @@ -111,9 +121,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { } } - // Initialize authority from options or configuration. - if err := a.init(); err != nil { - return nil, err + if !a.skipInit { + // Initialize authority from options or configuration. + if err := a.init(); err != nil { + return nil, err + } } return a, nil @@ -149,16 +161,41 @@ func NewEmbedded(opts ...Option) (*Authority, error) { // Initialize config required fields. a.config.Init() - // Initialize authority from options or configuration. - if err := a.init(); err != nil { - return nil, err + if !a.skipInit { + // Initialize authority from options or configuration. + if err := a.init(); err != nil { + return nil, err + } } return a, nil } -// reloadAdminResources reloads admins and provisioners from the DB. -func (a *Authority) reloadAdminResources(ctx context.Context) error { +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + if a, ok := FromContext(ctx); !ok { + panic("authority is not in the context") + } else { + return a + } +} + +// ReloadAdminResources reloads admins and provisioners from the DB. +func (a *Authority) ReloadAdminResources(ctx context.Context) error { var ( provList provisioner.List adminList []*linkedca.Admin @@ -213,6 +250,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error { a.provisioners = provClxn a.config.AuthorityConfig.Admins = adminList a.admins = adminClxn + return nil } @@ -224,6 +262,7 @@ func (a *Authority) init() error { } var err error + ctx := NewContext(context.Background(), a) // Set password if they are not set. var configPassword []byte @@ -259,10 +298,25 @@ func (a *Authority) init() error { if a.config.KMS != nil { options = *a.config.KMS } - a.keyManager, err = kms.New(context.Background(), options) + a.keyManager, err = kms.New(ctx, options) + if err != nil { + return err + } + } + + // Initialize linkedca client if necessary. On a linked RA, the issuer + // configuration might come from majordomo. + var linkedcaClient *linkedCaClient + if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil { + linkedcaClient, err = newLinkedCAClient(a.linkedCAToken) if err != nil { return err } + // If authorityId is configured make sure it matches the one in the token + if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) { + return errors.New("error initializing linkedca: token authority and configured authority do not match") + } + linkedcaClient.Run() } // Initialize the X.509 CA Service if it has not been set in the options. @@ -272,6 +326,22 @@ func (a *Authority) init() error { options = *a.config.AuthorityConfig.Options } + // Configure linked RA + if linkedcaClient != nil && options.CertificateAuthority == "" { + conf, err := linkedcaClient.GetConfiguration(ctx) + if err != nil { + return err + } + if conf.RaConfig != nil { + options.CertificateAuthority = conf.RaConfig.CaUrl + options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint + options.CertificateIssuer = &casapi.CertificateIssuer{ + Type: conf.RaConfig.Provisioner.Type.String(), + Provisioner: conf.RaConfig.Provisioner.Name, + } + } + } + // Set the issuer password if passed in the flags. if options.CertificateIssuer != nil && a.issuerPassword != nil { options.CertificateIssuer.Password = string(a.issuerPassword) @@ -292,7 +362,7 @@ func (a *Authority) init() error { } } - a.x509CAService, err = cas.New(context.Background(), options) + a.x509CAService, err = cas.New(ctx, options) if err != nil { return err } @@ -479,7 +549,7 @@ func (a *Authority) init() error { } } - a.scepService, err = scep.NewService(context.Background(), options) + a.scepService, err = scep.NewService(ctx, options) if err != nil { return err } @@ -491,40 +561,29 @@ func (a *Authority) init() error { // Initialize step-ca Admin Database if it's not already initialized using // WithAdminDB. if a.adminDB == nil { - if a.linkedCAToken == "" { - // Check if AuthConfig already exists - a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) - if err != nil { - return err - } + if linkedcaClient != nil { + a.adminDB = linkedcaClient } else { - // Use the linkedca client as the admindb. - client, err := newLinkedCAClient(a.linkedCAToken) + a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) if err != nil { return err } - // If authorityId is configured make sure it matches the one in the token - if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) { - return errors.New("error initializing linkedca: token authority and configured authority do not match") - } - client.Run() - a.adminDB = client } } - provs, err := a.adminDB.GetProvisioners(context.Background()) + provs, err := a.adminDB.GetProvisioners(ctx) if err != nil { return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") } if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { // Create First Provisioner - prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) + prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password)) if err != nil { return admin.WrapErrorISE(err, "error creating first provisioner") } // Create first admin - if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ + if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{ ProvisionerId: prov.Id, Subject: "step", Type: linkedca.Admin_SUPER_ADMIN, @@ -535,7 +594,12 @@ func (a *Authority) init() error { } // Load Provisioners and Admins - if err := a.reloadAdminResources(context.Background()); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { + return err + } + + // Load x509 and SSH Policy Engines + if err := a.reloadPolicyEngines(ctx); err != nil { return err } @@ -570,6 +634,15 @@ func (a *Authority) init() error { return nil } +// GetID returns the define authority id or a zero uuid. +func (a *Authority) GetID() string { + const zeroUUID = "00000000-0000-0000-0000-000000000000" + if id := a.config.AuthorityConfig.AuthorityID; id != "" { + return id + } + return zeroUUID +} + // GetDatabase returns the authority database. If the configuration does not // define a database, GetDatabase will return a db.SimpleDB instance. func (a *Authority) GetDatabase() db.AuthDB { @@ -581,6 +654,12 @@ func (a *Authority) GetAdminDatabase() admin.DB { return a.adminDB } +// GetConfig returns the config. +func (a *Authority) GetConfig() *config.Config { + return a.config +} + +// GetInfo returns information about the authority. func (a *Authority) GetInfo() Info { ai := Info{ StartTime: a.startTime, diff --git a/authority/authority_test.go b/authority/authority_test.go index 1f63333d..9f35f23e 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" @@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) { }) } } + +func TestAuthority_GetID(t *testing.T) { + type fields struct { + authorityID string + } + tests := []struct { + name string + fields fields + want string + }{ + {"ok", fields{""}, "00000000-0000-0000-0000-000000000000"}, + {"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + AuthorityID: tt.fields.authorityID, + }, + }, + } + if got := a.GetID(); got != tt.want { + t.Errorf("Authority.GetID() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/authorize.go b/authority/authorize.go index 7c1c2ff6..91f1b3cb 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "fmt" "net/http" "net/url" "strconv" @@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool { return m } -// authorizeToken parses the token and returns the provisioner used to generate -// the token. This method enforces the One-Time use policy (tokens can only be -// used once). -func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { - // Validate payload +// getProvisionerFromToken extracts a provisioner from the given token without +// doing any token validation. +func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) { tok, err := jose.ParseSigned(token) if err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token") + return nil, nil, fmt.Errorf("error parsing token: %w", err) } // Get claims w/out verification. We need to look up the provisioner @@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // before we can look up the provisioner. var claims Claims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") + return nil, nil, fmt.Errorf("error unmarshaling token: %w", err) + } + + // This method will also validate the audiences for JWK provisioners. + p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) + if !ok { + return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) + } + + return p, &claims, nil +} + +// authorizeToken parses the token and returns the provisioner used to generate +// the token. This method enforces the One-Time use policy (tokens can only be +// used once). +func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { + p, claims, err := a.getProvisionerFromToken(token) + if err != nil { + return nil, errs.UnauthorizedErr(err) } // TODO: use new persistence layer abstraction. @@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { - return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority") + return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority") } } - // This method will also validate the audiences for JWK provisioners. - p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) - if !ok { - return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+ - "not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) - } - // Store the token to protect against reuse unless it's skipped. // If we cannot get a token id from the provisioner, just hash the token. if !SkipTokenReuseFromContext(ctx) { @@ -130,22 +140,24 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. if err := claims.ValidateWithLeeway(jose.Expected{ - Issuer: prov.GetName(), - Time: time.Now().UTC(), + Time: time.Now().UTC(), }, time.Minute); err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims") } // validate audience: path matches the current path - if r.URL.Path != claims.Audience[0] { - return nil, admin.NewError(admin.ErrorUnauthorizedType, - "x5c.authorizeToken; x5c token has invalid audience "+ - "claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience) + if !matchesAudience(claims.Audience, a.config.Audience(r.URL.Path)) { + return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid audience claim (aud)") + } + + // validate issuer: old versions used the provisioner name, new version uses + // 'step-admin-client/1.0' + if claims.Issuer != "step-admin-client/1.0" && claims.Issuer != prov.GetName() { + return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid issuer claim (iss)") } if claims.Subject == "" { - return nil, admin.NewError(admin.ErrorUnauthorizedType, - "x5c.authorizeToken; x5c token subject cannot be empty") + return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token subject cannot be empty") } var ( @@ -156,7 +168,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...) adminSANs = append(adminSANs, leaf.EmailAddresses...) for _, san := range adminSANs { - if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok { + if adm, ok = a.LoadAdminBySubProv(san, prov.GetName()); ok { adminFound = true break } @@ -186,11 +198,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error { } ok, err := a.db.UseToken(reuseKey, token) if err != nil { - return errs.Wrap(http.StatusInternalServerError, err, - "authority.authorizeToken: failed when attempting to store token") + return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token") } if !ok { - return errs.Unauthorized("authority.authorizeToken: token already used") + return errs.Unauthorized("token already used") } } return nil @@ -249,10 +260,10 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio // AuthorizeSign authorizes a signature request by validating and authenticating // a token that must be sent w/ the request. // -// NOTE: This method is deprecated and should not be used. We make it available -// in the short term os as not to break existing clients. +// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error). func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + ctx := NewContext(context.Background(), a) + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) return a.Authorize(ctx, token) } @@ -285,9 +296,16 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { if isRevoked { return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } - p, ok := a.provisioners.LoadByCertificate(cert) - if !ok { - return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) + p, err := a.LoadProvisionerByCertificate(cert) + if err != nil { + var ok bool + // For backward compatibility this method will also succeed if the + // certificate does not have a provisioner extension. LoadByCertificate + // returns the noop provisioner if this happens, and it allows + // certificate renewals. + if p, ok = a.provisioners.LoadByCertificate(cert); !ok { + return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) + } } if err := p.AuthorizeRenew(context.Background(), cert); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) @@ -386,8 +404,8 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509. return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token")) } - p, ok := a.provisioners.LoadByCertificate(leaf) - if !ok { + p, err := a.LoadProvisionerByCertificate(leaf) + if err != nil { return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") } if err := a.UseToken(ott, p); err != nil { @@ -395,7 +413,6 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509. } if err := claims.ValidateWithLeeway(jose.Expected{ - Issuer: p.GetName(), Subject: leaf.Subject.CommonName, Time: time.Now().UTC(), }, time.Minute); err != nil { @@ -420,6 +437,12 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509. return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) } + // validate issuer: old versions used the provisioner name, new version uses + // 'step-ca-client/1.0' + if claims.Issuer != "step-ca-client/1.0" && claims.Issuer != p.GetName() { + return nil, admin.NewError(admin.ErrorUnauthorizedType, "error validating renew token: invalid issuer claim (iss)") + } + return leaf, nil } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 81e542c5..af80d3d3 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeToken: error parsing token"), + err: errors.New("error parsing token"), code: http.StatusUnauthorized, } }, @@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: raw, - err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"), + err: errors.New("token issued before the bootstrap of certificate authority"), code: http.StatusUnauthorized, } }, @@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: raw, - err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"), + err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"), code: http.StatusUnauthorized, } }, @@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: failed when attempting to store token: force"), + err: errors.New("failed when attempting to store token: force"), code: http.StatusInternalServerError, } }, @@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -491,7 +491,7 @@ func TestAuthority_authorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 7, got) + assert.Equals(t, 9, len(got)) // number of provisioner.SignOptions returned } } }) @@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: context.Background(), - err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), - err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), - err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, @@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, @@ -847,6 +847,29 @@ func TestAuthority_authorizeRenew(t *testing.T) { cert: fooCrt, } }, + "ok/from db": func(t *testing.T) *authorizeTest { + a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsRevoked: func(key string) (bool, error) { + return false, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + p, ok := a.provisioners.LoadByName("step-cli") + if !ok { + t.Fatal("provisioner step-cli not found") + } + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: p.GetID(), + }, + }, nil + }, + } + return &authorizeTest{ + auth: a, + cert: fooCrt, + } + }, } for name, genTestCase := range tests { @@ -965,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1011,7 +1034,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 7, got) + assert.Len(t, 9, got) // number of provisioner.SignOptions returned } } }) @@ -1059,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1167,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1259,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1322,7 +1345,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.cert.Serial, cert.Serial) - assert.Len(t, 3, signOpts) + assert.Len(t, 4, signOpts) } } }) @@ -1381,7 +1404,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { t1, c1 := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { @@ -1400,7 +1423,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { t2, c2 := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), IssuedAt: jose.NewNumericDate(now), @@ -1417,12 +1440,31 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { }) return nil })) - badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ + t3, c3 := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-cli", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-ca-client/1.0", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) @@ -1439,7 +1481,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { @@ -1477,7 +1519,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { badSubject, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "bad-subject", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { @@ -1496,7 +1538,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { @@ -1515,7 +1557,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { @@ -1534,7 +1576,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)), @@ -1554,7 +1596,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { badAudience, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", - Issuer: "step-cli", + Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { @@ -1584,6 +1626,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { }{ {"ok", a1, args{ctx, t1}, c1, false}, {"ok expired cert", a1, args{ctx, t2}, c2, false}, + {"ok provisioner issuer", a1, args{ctx, t3}, c3, false}, {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, {"fail token reuse", a1, args{ctx, t1}, nil, true}, {"fail token signature", a1, args{ctx, badSigner}, nil, true}, diff --git a/authority/config/config.go b/authority/config/config.go index 4bf51cfe..337bb71c 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -8,12 +8,15 @@ import ( "time" "github.com/pkg/errors" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" cas "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" kms "github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/certificates/templates" - "go.step.sm/linkedca" ) const ( @@ -26,27 +29,27 @@ var ( DefaultBackdate = time.Minute // DefaultDisableRenewal disables renewals per provisioner. DefaultDisableRenewal = false - // DefaultAllowRenewAfterExpiry allows renewals even if the certificate is + // DefaultAllowRenewalAfterExpiry allows renewals even if the certificate is // expired. - DefaultAllowRenewAfterExpiry = false + DefaultAllowRenewalAfterExpiry = false // DefaultEnableSSHCA enable SSH CA features per provisioner or globally // for all provisioners. DefaultEnableSSHCA = false // GlobalProvisionerClaims default claims for the Authority. Can be overridden // by provisioner specific claims. GlobalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &DefaultEnableSSHCA, - DisableRenewal: &DefaultDisableRenewal, - AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &DefaultEnableSSHCA, + DisableRenewal: &DefaultDisableRenewal, + AllowRenewalAfterExpiry: &DefaultAllowRenewalAfterExpiry, } ) @@ -68,6 +71,7 @@ type Config struct { TLS *TLSOptions `json:"tls,omitempty"` Password string `json:"password,omitempty"` Templates *templates.Templates `json:"templates,omitempty"` + CommonName string `json:"commonName,omitempty"` CRL *CRLConfig `json:"crl,omitempty"` } @@ -95,6 +99,7 @@ type AuthConfig struct { Admins []*linkedca.Admin `json:"-"` Template *ASN1DN `json:"template,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"` + Policy *policy.Options `json:"policy,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` Backdate *provisioner.Duration `json:"backdate,omitempty"` EnableAdmin bool `json:"enableAdmin,omitempty"` @@ -180,6 +185,9 @@ func (c *Config) Init() { if c.AuthorityConfig == nil { c.AuthorityConfig = &AuthConfig{} } + if c.CommonName == "" { + c.CommonName = "Step Online CA" + } c.AuthorityConfig.init() } @@ -311,6 +319,18 @@ func (c *Config) GetAudiences() provisioner.Audiences { return audiences } +// Audience returns the list of audiences for a given path. +func (c *Config) Audience(path string) []string { + audiences := make([]string, len(c.DNSNames)+1) + for i, name := range c.DNSNames { + hostname := toHostname(name) + audiences[i] = "https://" + hostname + path + } + // For backward compatibility + audiences[len(c.DNSNames)] = path + return audiences +} + func toHostname(name string) string { // ensure an IPv6 address is represented with square brackets when used as hostname if ip := net.ParseIP(name); ip != nil && ip.To4() == nil { diff --git a/authority/config/config_test.go b/authority/config/config_test.go index b921be13..5a05b3f6 100644 --- a/authority/config/config_test.go +++ b/authority/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "reflect" "testing" "github.com/pkg/errors" @@ -317,3 +318,38 @@ func Test_toHostname(t *testing.T) { }) } } + +func TestConfig_Audience(t *testing.T) { + type fields struct { + DNSNames []string + } + type args struct { + path string + } + tests := []struct { + name string + fields fields + args args + want []string + }{ + {"ok", fields{[]string{ + "ca", "ca.example.com", "127.0.0.1", "::1", + }}, args{"/path"}, []string{ + "https://ca/path", + "https://ca.example.com/path", + "https://127.0.0.1/path", + "https://[::1]/path", + "/path", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + DNSNames: tt.fields.DNSNames, + } + if got := c.Audience(tt.args.path); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.Audience() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/linkedca.go b/authority/linkedca.go index b568dcbb..0b98f877 100644 --- a/authority/linkedca.go +++ b/authority/linkedca.go @@ -15,15 +15,19 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/db" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/tlsutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" - "golang.org/x/crypto/ssh" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" ) const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$" @@ -34,6 +38,9 @@ type linkedCaClient struct { authorityID string } +// interface guard +var _ admin.DB = (*linkedCaClient)(nil) + type linkedCAClaims struct { jose.Claims SANs []string `json:"sans"` @@ -115,6 +122,13 @@ func newLinkedCAClient(token string) (*linkedCaClient, error) { }, nil } +// IsLinkedCA is a sentinel function that can be used to +// check if a linkedCaClient is the underlying type of an +// admin.DB interface. +func (c *linkedCaClient) IsLinkedCA() bool { + return true +} + func (c *linkedCaClient) Run() { c.renewer.Run() } @@ -151,13 +165,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked } func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { + resp, err := c.GetConfiguration(ctx) + if err != nil { + return nil, err + } + return resp.Provisioners, nil +} + +func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) { resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ AuthorityId: c.authorityID, }) if err != nil { - return nil, errors.Wrap(err, "error getting provisioners") + return nil, errors.Wrap(err, "error getting configuration") } - return resp.Provisioners, nil + return resp, nil } func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { @@ -204,11 +226,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm } func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { - resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ - AuthorityId: c.authorityID, - }) + resp, err := c.GetConfiguration(ctx) if err != nil { - return nil, errors.Wrap(err, "error getting admins") + return nil, err } return resp.Admins, nil } @@ -228,12 +248,35 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error { return errors.Wrap(err, "error deleting admin") } -func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error { +func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, error) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + resp, err := c.client.GetCertificate(ctx, &linkedca.GetCertificateRequest{ + Serial: serial, + }) + if err != nil { + return nil, err + } + + var pd *db.ProvisionerData + if p := resp.Provisioner; p != nil { + pd = &db.ProvisionerData{ + ID: p.Id, Name: p.Name, Type: p.Type.String(), + } + } + return &db.CertificateData{ + Provisioner: pd, + }, nil +} + +func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificateChain: serializeCertificateChain(fullchain[1:]...), + Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting certificate") } @@ -246,18 +289,30 @@ func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullc PemCertificateChain: serializeCertificateChain(fullchain[1:]...), PemParentCertificate: serializeCertificateChain(parent), }) - return errors.Wrap(err, "error posting certificate") + return errors.Wrap(err, "error posting renewed certificate") } -func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error { +func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ Certificate: string(ssh.MarshalAuthorizedKey(crt)), + Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting ssh certificate") } +func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ + Certificate: string(ssh.MarshalAuthorizedKey(crt)), + ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)), + Provisioner: createProvisionerIdentity(p), + }) + return errors.Wrap(err, "error posting renewed ssh certificate") +} + func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() @@ -310,6 +365,33 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { return resp.Status != linkedca.RevocationStatus_ACTIVE, nil } +func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") +} + +func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error { + return errors.New("not implemented yet") +} + +func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity { + if p == nil { + return nil + } + return &linkedca.ProvisionerIdentity{ + Id: p.GetID(), + Type: linkedca.Provisioner_Type(p.GetType()), + Name: p.GetName(), + } +} + func serializeCertificate(crt *x509.Certificate) string { if crt == nil { return "" diff --git a/authority/options.go b/authority/options.go index 1c154577..6e1949f5 100644 --- a/authority/options.go +++ b/authority/options.go @@ -266,6 +266,16 @@ func WithAdminDB(d admin.DB) Option { } } +// WithProvisioners is an option to set the provisioner collection. +// +// Deprecated: provisioner collections will likely change +func WithProvisioners(ps *provisioner.Collection) Option { + return func(a *Authority) error { + a.provisioners = ps + return nil + } +} + // WithLinkedCAToken is an option to set the authentication token used to enable // linked ca. func WithLinkedCAToken(token string) Option { @@ -284,6 +294,15 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option { } } +// WithSkipInit is an option that allows the constructor to skip initializtion +// of the authority. +func WithSkipInit() Option { + return func(a *Authority) error { + a.skipInit = true + return nil + } +} + func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) { var block *pem.Block var certs []*x509.Certificate diff --git a/authority/policy.go b/authority/policy.go new file mode 100644 index 00000000..258873af --- /dev/null +++ b/authority/policy.go @@ -0,0 +1,265 @@ +package authority + +import ( + "context" + "errors" + "fmt" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/admin" + authPolicy "github.com/smallstep/certificates/authority/policy" + policy "github.com/smallstep/certificates/policy" +) + +type policyErrorType int + +const ( + AdminLockOut policyErrorType = iota + 1 + StoreFailure + ReloadFailure + ConfigurationFailure + EvaluationFailure + InternalFailure +) + +type PolicyError struct { + Typ policyErrorType + Err error +} + +func (p *PolicyError) Error() string { + return p.Err.Error() +} + +func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + p, err := a.adminDB.GetAuthorityPolicy(ctx) + if err != nil { + return nil, &PolicyError{ + Typ: InternalFailure, + Err: err, + } + } + + return p, nil +} + +func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil { + return nil, err + } + + if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil { + return nil, &PolicyError{ + Typ: StoreFailure, + Err: err, + } + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return nil, &PolicyError{ + Typ: ReloadFailure, + Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err), + } + } + + return p, nil +} + +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil { + return nil, err + } + + if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil { + return nil, &PolicyError{ + Typ: StoreFailure, + Err: err, + } + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return nil, &PolicyError{ + Typ: ReloadFailure, + Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err), + } + } + + return p, nil +} + +func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil { + return &PolicyError{ + Typ: StoreFailure, + Err: err, + } + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return &PolicyError{ + Typ: ReloadFailure, + Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err), + } + } + + return nil +} + +func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error { + + // no policy and thus nothing to evaluate; return early + if p == nil { + return nil + } + + // get all current admins from the database + allAdmins, err := a.adminDB.GetAdmins(ctx) + if err != nil { + return &PolicyError{ + Typ: InternalFailure, + Err: fmt.Errorf("error retrieving admins: %w", err), + } + } + + return a.checkPolicy(ctx, currentAdmin, allAdmins, p) +} + +func (a *Authority) checkProvisionerPolicy(ctx context.Context, provName string, p *linkedca.Policy) error { + + // no policy and thus nothing to evaluate; return early + if p == nil { + return nil + } + + // get all admins for the provisioner; ignoring case in which they're not found + allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName) + + // check the policy; pass in nil as the current admin, as all admins for the + // provisioner will be checked by looping through allProvisionerAdmins. Also, + // the current admin may be a super admin not belonging to the provisioner, so + // can't be blocked, but is not required to be in the policy, either. + return a.checkPolicy(ctx, nil, allProvisionerAdmins, p) +} + +// checkPolicy checks if a new or updated policy configuration results in the user +// locking themselves or other admins out of the CA. +func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error { + + // convert the policy; return early if nil + policyOptions := authPolicy.LinkedToCertificates(p) + if policyOptions == nil { + return nil + } + + engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options()) + if err != nil { + return &PolicyError{ + Typ: ConfigurationFailure, + Err: err, + } + } + + // when an empty X.509 policy is provided, the resulting engine is nil + // and there's no policy to evaluate. + if engine == nil { + return nil + } + + // TODO(hs): Provide option to force the policy, even when the admin subject would be locked out? + + // check if the admin user that instructed the authority policy to be + // created or updated, would still be allowed when the provided policy + // would be applied. This case is skipped when current admin is nil, which + // is the case when a provisioner policy is checked. + if currentAdmin != nil { + sans := []string{currentAdmin.GetSubject()} + if err := isAllowed(engine, sans); err != nil { + return err + } + } + + // loop through admins to verify that none of them would be + // locked out when the new policy were to be applied. Returns + // an error with a message that includes the admin subject that + // would be locked out. + for _, adm := range otherAdmins { + sans := []string{adm.GetSubject()} + if err := isAllowed(engine, sans); err != nil { + return err + } + } + + // TODO(hs): mask the error message for non-super admins? + + return nil +} + +// reloadPolicyEngines reloads x509 and SSH policy engines using +// configuration stored in the DB or from the configuration file. +func (a *Authority) reloadPolicyEngines(ctx context.Context) error { + var ( + err error + policyOptions *authPolicy.Options + ) + + if a.config.AuthorityConfig.EnableAdmin { + + // temporarily disable policy loading when LinkedCA is in use + if _, ok := a.adminDB.(*linkedCaClient); ok { + return nil + } + + linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx) + if err != nil { + var ae *admin.Error + if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError { + return fmt.Errorf("error getting policy to (re)load policy engines: %w", err) + } + } + policyOptions = authPolicy.LinkedToCertificates(linkedPolicy) + } else { + policyOptions = a.config.AuthorityConfig.Policy + } + + engine, err := authPolicy.New(policyOptions) + if err != nil { + return err + } + + // only update the policy engine when no error was returned + a.policyEngine = engine + + return nil +} + +func isAllowed(engine authPolicy.X509Policy, sans []string) error { + if err := engine.AreSANsAllowed(sans); err != nil { + var policyErr *policy.NamePolicyError + isNamePolicyError := errors.As(err, &policyErr) + if isNamePolicyError && policyErr.Reason == policy.NotAllowed { + return &PolicyError{ + Typ: AdminLockOut, + Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans), + } + } + return &PolicyError{ + Typ: EvaluationFailure, + Err: err, + } + } + + return nil +} diff --git a/authority/policy/engine.go b/authority/policy/engine.go new file mode 100644 index 00000000..4b21f66b --- /dev/null +++ b/authority/policy/engine.go @@ -0,0 +1,114 @@ +package policy + +import ( + "crypto/x509" + "errors" + "fmt" + + "golang.org/x/crypto/ssh" +) + +// Engine is a container for multiple policies. +type Engine struct { + x509Policy X509Policy + sshUserPolicy UserPolicy + sshHostPolicy HostPolicy +} + +// New returns a new Engine using Options. +func New(options *Options) (*Engine, error) { + + // if no options provided, return early + if options == nil { + return nil, nil + } + + var ( + x509Policy X509Policy + sshHostPolicy HostPolicy + sshUserPolicy UserPolicy + err error + ) + + // initialize the x509 allow/deny policy engine + if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil { + return nil, err + } + + // initialize the SSH allow/deny policy engine for host certificates + if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + // initialize the SSH allow/deny policy engine for user certificates + if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + return &Engine{ + x509Policy: x509Policy, + sshHostPolicy: sshHostPolicy, + sshUserPolicy: sshUserPolicy, + }, nil +} + +// IsX509CertificateAllowed evaluates an X.509 certificate against +// the X.509 policy (if available) and returns an error if one of the +// names in the certificate is not allowed. +func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error { + + // return early if there's no policy to evaluate + if e == nil || e.x509Policy == nil { + return nil + } + + // return result of X.509 policy evaluation + return e.x509Policy.IsX509CertificateAllowed(cert) +} + +// AreSANsAllowed evaluates the slice of SANs against the X.509 policy +// (if available) and returns an error if one of the SANs is not allowed. +func (e *Engine) AreSANsAllowed(sans []string) error { + + // return early if there's no policy to evaluate + if e == nil || e.x509Policy == nil { + return nil + } + + // return result of X.509 policy evaluation + return e.x509Policy.AreSANsAllowed(sans) +} + +// IsSSHCertificateAllowed evaluates an SSH certificate against the +// user or host policy (if configured) and returns an error if one of the +// principals in the certificate is not allowed. +func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error { + + // return early if there's no policy to evaluate + if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) { + return nil + } + + switch cert.CertType { + case ssh.HostCert: + // when no host policy engine is configured, but a user policy engine is + // configured, the host certificate is denied. + if e.sshHostPolicy == nil && e.sshUserPolicy != nil { + return errors.New("authority not allowed to sign ssh host certificates") + } + + // return result of SSH host policy evaluation + return e.sshHostPolicy.IsSSHCertificateAllowed(cert) + case ssh.UserCert: + // when no user policy engine is configured, but a host policy engine is + // configured, the user certificate is denied. + if e.sshUserPolicy == nil && e.sshHostPolicy != nil { + return errors.New("authority not allowed to sign ssh user certificates") + } + + // return result of SSH user policy evaluation + return e.sshUserPolicy.IsSSHCertificateAllowed(cert) + default: + return fmt.Errorf("unexpected ssh certificate type %q", cert.CertType) + } +} diff --git a/authority/policy/options.go b/authority/policy/options.go new file mode 100644 index 00000000..b93d2cd1 --- /dev/null +++ b/authority/policy/options.go @@ -0,0 +1,194 @@ +package policy + +// Options is a container for authority level x509 and SSH +// policy configuration. +type Options struct { + X509 *X509PolicyOptions `json:"x509,omitempty"` + SSH *SSHPolicyOptions `json:"ssh,omitempty"` +} + +// GetX509Options returns the x509 authority level policy +// configuration +func (o *Options) GetX509Options() *X509PolicyOptions { + if o == nil { + return nil + } + return o.X509 +} + +// GetSSHOptions returns the SSH authority level policy +// configuration +func (o *Options) GetSSHOptions() *SSHPolicyOptions { + if o == nil { + return nil + } + return o.SSH +} + +// X509PolicyOptionsInterface is an interface for providers +// of x509 allowed and denied names. +type X509PolicyOptionsInterface interface { + GetAllowedNameOptions() *X509NameOptions + GetDeniedNameOptions() *X509NameOptions + AreWildcardNamesAllowed() bool +} + +// X509PolicyOptions is a container for x509 allowed and denied +// names. +type X509PolicyOptions struct { + // AllowedNames contains the x509 allowed names + AllowedNames *X509NameOptions `json:"allow,omitempty"` + + // DeniedNames contains the x509 denied names + DeniedNames *X509NameOptions `json:"deny,omitempty"` + + // AllowWildcardNames indicates if literal wildcard names + // like *.example.com are allowed. Defaults to false. + AllowWildcardNames bool `json:"allowWildcardNames,omitempty"` +} + +// X509NameOptions models the X509 name policy configuration. +type X509NameOptions struct { + CommonNames []string `json:"cn,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +// HasNames checks if the AllowedNameOptions has one or more +// names configured. +func (o *X509NameOptions) HasNames() bool { + return len(o.CommonNames) > 0 || + len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.URIDomains) > 0 +} + +// GetAllowedNameOptions returns x509 allowed name policy configuration +func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the x509 denied name policy configuration +func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// AreWildcardNamesAllowed returns whether the authority allows +// literal wildcard names to be signed. +func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool { + if o == nil { + return true + } + return o.AllowWildcardNames +} + +// SSHPolicyOptionsInterface is an interface for providers of +// SSH user and host name policy configuration. +type SSHPolicyOptionsInterface interface { + GetAllowedUserNameOptions() *SSHNameOptions + GetDeniedUserNameOptions() *SSHNameOptions + GetAllowedHostNameOptions() *SSHNameOptions + GetDeniedHostNameOptions() *SSHNameOptions +} + +// SSHPolicyOptions is a container for SSH user and host policy +// configuration +type SSHPolicyOptions struct { + // User contains SSH user certificate options. + User *SSHUserCertificateOptions `json:"user,omitempty"` + // Host contains SSH host certificate options. + Host *SSHHostCertificateOptions `json:"host,omitempty"` +} + +// GetAllowedUserNameOptions returns the SSH allowed user name policy +// configuration. +func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions { + if o == nil || o.User == nil { + return nil + } + return o.User.AllowedNames +} + +// GetDeniedUserNameOptions returns the SSH denied user name policy +// configuration. +func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions { + if o == nil || o.User == nil { + return nil + } + return o.User.DeniedNames +} + +// GetAllowedHostNameOptions returns the SSH allowed host name policy +// configuration. +func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions { + if o == nil || o.Host == nil { + return nil + } + return o.Host.AllowedNames +} + +// GetDeniedHostNameOptions returns the SSH denied host name policy +// configuration. +func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions { + if o == nil || o.Host == nil { + return nil + } + return o.Host.DeniedNames +} + +// SSHUserCertificateOptions is a collection of SSH user certificate options. +type SSHUserCertificateOptions struct { + // AllowedNames contains the names the provisioner is authorized to sign + AllowedNames *SSHNameOptions `json:"allow,omitempty"` + // DeniedNames contains the names the provisioner is not authorized to sign + DeniedNames *SSHNameOptions `json:"deny,omitempty"` +} + +// SSHHostCertificateOptions is a collection of SSH host certificate options. +// It's an alias of SSHUserCertificateOptions, as the options are the same +// for both types of certificates. +type SSHHostCertificateOptions SSHUserCertificateOptions + +// SSHNameOptions models the SSH name policy configuration. +type SSHNameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the +// names that a provisioner is authorized to sign SSH certificates for. +func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the +// names that a provisioner is NOT authorized to sign SSH certificates for. +func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// HasNames checks if the SSHNameOptions has one or more +// names configured. +func (o *SSHNameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.Principals) > 0 +} diff --git a/authority/policy/options_test.go b/authority/policy/options_test.go new file mode 100644 index 00000000..0fd6e7c6 --- /dev/null +++ b/authority/policy/options_test.go @@ -0,0 +1,45 @@ +package policy + +import ( + "testing" +) + +func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) { + tests := []struct { + name string + options *X509PolicyOptions + want bool + }{ + { + name: "nil-options", + options: nil, + want: true, + }, + { + name: "not-set", + options: &X509PolicyOptions{}, + want: false, + }, + { + name: "set-true", + options: &X509PolicyOptions{ + AllowWildcardNames: true, + }, + want: true, + }, + { + name: "set-false", + options: &X509PolicyOptions{ + AllowWildcardNames: false, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.options.AreWildcardNamesAllowed(); got != tt.want { + t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/policy/policy.go b/authority/policy/policy.go new file mode 100644 index 00000000..3c53b704 --- /dev/null +++ b/authority/policy/policy.go @@ -0,0 +1,256 @@ +package policy + +import ( + "fmt" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/policy" +) + +// X509Policy is an alias for policy.X509NamePolicyEngine +type X509Policy policy.X509NamePolicyEngine + +// UserPolicy is an alias for policy.SSHNamePolicyEngine +type UserPolicy policy.SSHNamePolicyEngine + +// HostPolicy is an alias for policy.SSHNamePolicyEngine +type HostPolicy policy.SSHNamePolicyEngine + +// NewX509PolicyEngine creates a new x509 name policy engine +func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) { + + // return early if no policy engine options to configure + if policyOptions == nil { + return nil, nil + } + + options := []policy.NamePolicyOption{} + + allowed := policyOptions.GetAllowedNameOptions() + if allowed != nil && allowed.HasNames() { + options = append(options, + policy.WithPermittedCommonNames(allowed.CommonNames...), + policy.WithPermittedDNSDomains(allowed.DNSDomains...), + policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...), + policy.WithPermittedEmailAddresses(allowed.EmailAddresses...), + policy.WithPermittedURIDomains(allowed.URIDomains...), + ) + } + + denied := policyOptions.GetDeniedNameOptions() + if denied != nil && denied.HasNames() { + options = append(options, + policy.WithExcludedCommonNames(denied.CommonNames...), + policy.WithExcludedDNSDomains(denied.DNSDomains...), + policy.WithExcludedIPsOrCIDRs(denied.IPRanges...), + policy.WithExcludedEmailAddresses(denied.EmailAddresses...), + policy.WithExcludedURIDomains(denied.URIDomains...), + ) + } + + // ensure no policy engine is returned when no name options were provided + if len(options) == 0 { + return nil, nil + } + + // check if configuration specifies that wildcard names are allowed + if policyOptions.AreWildcardNamesAllowed() { + options = append(options, policy.WithAllowLiteralWildcardNames()) + } + + // enable subject common name verification by default + options = append(options, policy.WithSubjectCommonNameVerification()) + + return policy.New(options...) +} + +type sshPolicyEngineType string + +const ( + UserPolicyEngineType sshPolicyEngineType = "user" + HostPolicyEngineType sshPolicyEngineType = "host" +) + +// newSSHUserPolicyEngine creates a new SSH user certificate policy engine +func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) { + policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType) + if err != nil { + return nil, err + } + return policyEngine, nil +} + +// newSSHHostPolicyEngine create a new SSH host certificate policy engine +func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) { + policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType) + if err != nil { + return nil, err + } + return policyEngine, nil +} + +// newSSHPolicyEngine creates a new SSH name policy engine +func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) { + + // return early if no policy engine options to configure + if policyOptions == nil { + return nil, nil + } + + var ( + allowed *SSHNameOptions + denied *SSHNameOptions + ) + + switch typ { + case UserPolicyEngineType: + allowed = policyOptions.GetAllowedUserNameOptions() + denied = policyOptions.GetDeniedUserNameOptions() + case HostPolicyEngineType: + allowed = policyOptions.GetAllowedHostNameOptions() + denied = policyOptions.GetDeniedHostNameOptions() + default: + return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ) + } + + options := []policy.NamePolicyOption{} + + if allowed != nil && allowed.HasNames() { + options = append(options, + policy.WithPermittedDNSDomains(allowed.DNSDomains...), + policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...), + policy.WithPermittedEmailAddresses(allowed.EmailAddresses...), + policy.WithPermittedPrincipals(allowed.Principals...), + ) + } + + if denied != nil && denied.HasNames() { + options = append(options, + policy.WithExcludedDNSDomains(denied.DNSDomains...), + policy.WithExcludedIPsOrCIDRs(denied.IPRanges...), + policy.WithExcludedEmailAddresses(denied.EmailAddresses...), + policy.WithExcludedPrincipals(denied.Principals...), + ) + } + + // ensure no policy engine is returned when no name options were provided + if len(options) == 0 { + return nil, nil + } + + return policy.New(options...) +} + +func LinkedToCertificates(p *linkedca.Policy) *Options { + + // return early + if p == nil { + return nil + } + + // return early if x509 nor SSH is set + if p.GetX509() == nil && p.GetSsh() == nil { + return nil + } + + opts := &Options{} + + // fill x509 policy configuration + if x509 := p.GetX509(); x509 != nil { + opts.X509 = &X509PolicyOptions{} + if allow := x509.GetAllow(); allow != nil { + opts.X509.AllowedNames = &X509NameOptions{} + if allow.Dns != nil { + opts.X509.AllowedNames.DNSDomains = allow.Dns + } + if allow.Ips != nil { + opts.X509.AllowedNames.IPRanges = allow.Ips + } + if allow.Emails != nil { + opts.X509.AllowedNames.EmailAddresses = allow.Emails + } + if allow.Uris != nil { + opts.X509.AllowedNames.URIDomains = allow.Uris + } + if allow.CommonNames != nil { + opts.X509.AllowedNames.CommonNames = allow.CommonNames + } + } + if deny := x509.GetDeny(); deny != nil { + opts.X509.DeniedNames = &X509NameOptions{} + if deny.Dns != nil { + opts.X509.DeniedNames.DNSDomains = deny.Dns + } + if deny.Ips != nil { + opts.X509.DeniedNames.IPRanges = deny.Ips + } + if deny.Emails != nil { + opts.X509.DeniedNames.EmailAddresses = deny.Emails + } + if deny.Uris != nil { + opts.X509.DeniedNames.URIDomains = deny.Uris + } + if deny.CommonNames != nil { + opts.X509.DeniedNames.CommonNames = deny.CommonNames + } + } + + opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames() + } + + // fill ssh policy configuration + if ssh := p.GetSsh(); ssh != nil { + opts.SSH = &SSHPolicyOptions{} + if host := ssh.GetHost(); host != nil { + opts.SSH.Host = &SSHHostCertificateOptions{} + if allow := host.GetAllow(); allow != nil { + opts.SSH.Host.AllowedNames = &SSHNameOptions{} + if allow.Dns != nil { + opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns + } + if allow.Ips != nil { + opts.SSH.Host.AllowedNames.IPRanges = allow.Ips + } + if allow.Principals != nil { + opts.SSH.Host.AllowedNames.Principals = allow.Principals + } + } + if deny := host.GetDeny(); deny != nil { + opts.SSH.Host.DeniedNames = &SSHNameOptions{} + if deny.Dns != nil { + opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns + } + if deny.Ips != nil { + opts.SSH.Host.DeniedNames.IPRanges = deny.Ips + } + if deny.Principals != nil { + opts.SSH.Host.DeniedNames.Principals = deny.Principals + } + } + } + if user := ssh.GetUser(); user != nil { + opts.SSH.User = &SSHUserCertificateOptions{} + if allow := user.GetAllow(); allow != nil { + opts.SSH.User.AllowedNames = &SSHNameOptions{} + if allow.Emails != nil { + opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails + } + if allow.Principals != nil { + opts.SSH.User.AllowedNames.Principals = allow.Principals + } + } + if deny := user.GetDeny(); deny != nil { + opts.SSH.User.DeniedNames = &SSHNameOptions{} + if deny.Emails != nil { + opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails + } + if deny.Principals != nil { + opts.SSH.User.DeniedNames.Principals = deny.Principals + } + } + } + } + + return opts +} diff --git a/authority/policy/policy_test.go b/authority/policy/policy_test.go new file mode 100644 index 00000000..9210ad90 --- /dev/null +++ b/authority/policy/policy_test.go @@ -0,0 +1,155 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "go.step.sm/linkedca" +) + +func TestPolicyToCertificates(t *testing.T) { + type args struct { + policy *linkedca.Policy + } + tests := []struct { + name string + args args + want *Options + }{ + { + name: "nil", + args: args{ + policy: nil, + }, + want: nil, + }, + { + name: "no-policy", + args: args{ + &linkedca.Policy{}, + }, + want: nil, + }, + { + name: "partial-policy", + args: args{ + &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + AllowWildcardNames: false, + }, + }, + }, + want: &Options{ + X509: &X509PolicyOptions{ + AllowedNames: &X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + AllowWildcardNames: false, + }, + }, + }, + { + name: "full-policy", + args: args{ + &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step"}, + Ips: []string{"127.0.0.1/24"}, + Emails: []string{"*.example.com"}, + Uris: []string{"https://*.local"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad"}, + Ips: []string{"127.0.0.30"}, + Emails: []string{"badhost.example.com"}, + Uris: []string{"https://badhost.local"}, + CommonNames: []string{"another name"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.localhost"}, + Ips: []string{"127.0.0.1/24"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.localhost"}, + Ips: []string{"127.0.0.40"}, + Principals: []string{"root"}, + }, + }, + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@work"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@work"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + want: &Options{ + X509: &X509PolicyOptions{ + AllowedNames: &X509NameOptions{ + DNSDomains: []string{"step"}, + IPRanges: []string{"127.0.0.1/24"}, + EmailAddresses: []string{"*.example.com"}, + URIDomains: []string{"https://*.local"}, + CommonNames: []string{"some name"}, + }, + DeniedNames: &X509NameOptions{ + DNSDomains: []string{"bad"}, + IPRanges: []string{"127.0.0.30"}, + EmailAddresses: []string{"badhost.example.com"}, + URIDomains: []string{"https://badhost.local"}, + CommonNames: []string{"another name"}, + }, + AllowWildcardNames: true, + }, + SSH: &SSHPolicyOptions{ + Host: &SSHHostCertificateOptions{ + AllowedNames: &SSHNameOptions{ + DNSDomains: []string{"*.localhost"}, + IPRanges: []string{"127.0.0.1/24"}, + Principals: []string{"user"}, + }, + DeniedNames: &SSHNameOptions{ + DNSDomains: []string{"badhost.localhost"}, + IPRanges: []string{"127.0.0.40"}, + Principals: []string{"root"}, + }, + }, + User: &SSHUserCertificateOptions{ + AllowedNames: &SSHNameOptions{ + EmailAddresses: []string{"@work"}, + Principals: []string{"user"}, + }, + DeniedNames: &SSHNameOptions{ + EmailAddresses: []string{"root@work"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := LinkedToCertificates(tt.args.policy) + if !cmp.Equal(tt.want, got) { + t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got)) + } + }) + } +} diff --git a/authority/policy_test.go b/authority/policy_test.go new file mode 100644 index 00000000..1dccf0d1 --- /dev/null +++ b/authority/policy_test.go @@ -0,0 +1,1625 @@ +package authority + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/administrator" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" +) + +func TestAuthority_checkPolicy(t *testing.T) { + type test struct { + ctx context.Context + currentAdmin *linkedca.Admin + otherAdmins []*linkedca.Admin + policy *linkedca.Policy + err *PolicyError + } + tests := map[string]func(t *testing.T) test{ + "fail/NewX509PolicyEngine-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"**.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: ConfigurationFailure, + Err: errors.New("cannot parse permitted domain constraint \"**.local\": domain constraint \"**.local\" can only have wildcard as starting character"), + }, + } + }, + "fail/currentAdmin-evaluation-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "*"}, + otherAdmins: []*linkedca.Admin{}, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: EvaluationFailure, + Err: errors.New("cannot parse dns domain \"*\""), + }, + } + }, + "fail/currentAdmin-lockout": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "otherAdmin", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: AdminLockOut, + Err: errors.New("the provided policy would lock out [step] from the CA. Please update your policy to include [step] as an allowed name"), + }, + } + }, + "fail/otherAdmins-evaluation-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "other", + }, + { + Subject: "**", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "other", "*.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: EvaluationFailure, + Err: errors.New("cannot parse dns domain \"**\""), + }, + } + }, + "fail/otherAdmins-lockout": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "otherAdmin", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step"}, + }, + }, + }, + err: &PolicyError{ + Typ: AdminLockOut, + Err: errors.New("the provided policy would lock out [otherAdmin] from the CA. Please update your policy to include [otherAdmin] as an allowed name"), + }, + } + }, + "ok/no-policy": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{}, + policy: nil, + } + }, + "ok/empty-policy": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{}, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{}, + }, + }, + }, + } + }, + "ok/policy": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "otherAdmin", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + a := &Authority{} + + err := a.checkPolicy(tc.ctx, tc.currentAdmin, tc.otherAdmins, tc.policy) + + if tc.err == nil { + assert.Nil(t, err) + } else { + assert.IsType(t, &PolicyError{}, err) + + pe, ok := err.(*PolicyError) + assert.True(t, ok) + + assert.Equal(t, tc.err.Typ, pe.Typ) + assert.Equal(t, tc.err.Error(), pe.Error()) + } + }) + } +} + +func mustPolicyEngine(t *testing.T, options *policy.Options) *policy.Engine { + engine, err := policy.New(options) + if err != nil { + t.Fatal(err) + } + return engine +} + +func TestAuthority_reloadPolicyEngines(t *testing.T) { + + existingPolicyEngine, err := policy.New(&policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.hosts.example.com"}, + }, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.hosts.example.com"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@mails.example.com"}, + }, + }, + }, + }) + assert.NoError(t, err) + + newX509Options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + } + + newSSHHostOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + }, + } + + newSSHUserOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + } + + newSSHOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + } + + newOptions := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + } + + newAdminX509Options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + } + + newAdminSSHHostOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + } + + newAdminSSHUserOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@example.com"}, + }, + }, + }, + } + + newAdminOptions := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@example.com"}, + }, + DeniedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"baduser@example.com"}, + }, + }, + }, + } + + tests := []struct { + name string + config *config.Config + adminDB admin.DB + ctx context.Context + expected *policy.Engine + wantErr bool + }{ + { + name: "fail/standalone-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"**.local"}, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/standalone-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"**.local"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/standalone-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"**example.com"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/adminDB.GetAuthorityPolicy-error", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/admin-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"**.local"}, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/admin-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"**.local"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/admin-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@@example.com"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "ok/linkedca-unsupported", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &linkedCaClient{}, + ctx: context.Background(), + wantErr: false, + expected: existingPolicyEngine, + }, + { + name: "ok/standalone-no-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: nil, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, nil), + }, + { + name: "ok/standalone-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newX509Options), + }, + { + name: "ok/standalone-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newSSHHostOptions), + }, + { + name: "ok/standalone-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newSSHUserOptions), + }, + { + name: "ok/standalone-ssh-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newSSHOptions), + }, + { + name: "ok/standalone-full-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newOptions), + }, + { + name: "ok/admin-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newAdminX509Options), + }, + { + name: "ok/admin-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newAdminSSHHostOptions), + }, + { + name: "ok/admin-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newAdminSSHUserOptions), + }, + { + name: "ok/admin-full-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + ctx: context.Background(), + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + }, + }, + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"baduser@example.com"}, + }, + }, + }, + }, nil + }, + }, + wantErr: false, + expected: mustPolicyEngine(t, newAdminOptions), + }, + { + // both DB and JSON config; DB config is taken if Admin API is enabled + name: "ok/admin-over-standalone", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + }, nil + }, + }, + wantErr: false, + expected: mustPolicyEngine(t, newX509Options), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.config, + adminDB: tt.adminDB, + policyEngine: existingPolicyEngine, + } + if err := a.reloadPolicyEngines(tt.ctx); (err != nil) != tt.wantErr { + t.Errorf("Authority.reloadPolicyEngines() error = %v, wantErr %v", err, tt.wantErr) + } + + assert.Equal(t, tt.expected, a.policyEngine) + }) + } +} + +func TestAuthority_checkAuthorityPolicy(t *testing.T) { + type fields struct { + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + currentAdmin *linkedca.Admin + provName string + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "no policy", + fields: fields{}, + args: args{ + currentAdmin: nil, + provName: "prov", + p: nil, + }, + wantErr: false, + }, + { + name: "fail/adminDB.GetAdmins-error", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "fail/policy", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + { + Id: "adminID1", + Subject: "anotherAdmin", + }, + { + Id: "adminID2", + Subject: "step", + }, + { + Id: "adminID3", + Subject: "otherAdmin", + }, + }, nil + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + { + Id: "adminID2", + Subject: "step", + }, + { + Id: "adminID3", + Subject: "otherAdmin", + }, + }, nil + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + provisioners: tt.fields.provisioners, + admins: tt.fields.admins, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + if err := a.checkAuthorityPolicy(tt.args.ctx, tt.args.currentAdmin, tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthority_checkProvisionerPolicy(t *testing.T) { + jwkProvisioner := &provisioner.JWK{ + ID: "jwkID", + Type: "JWK", + Name: "jwkProv", + Key: &jose.JSONWebKey{KeyID: "jwkKeyID"}, + } + provisioners := provisioner.NewCollection(testAudiences) + provisioners.Store(jwkProvisioner) + admins := administrator.NewCollection(provisioners) + admins.Store(&linkedca.Admin{ + Id: "adminID", + Subject: "step", + ProvisionerId: "jwkID", + }, jwkProvisioner) + type fields struct { + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + provName string + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "no policy", + fields: fields{}, + args: args{ + provName: "prov", + p: nil, + }, + wantErr: false, + }, + { + name: "fail/policy", + fields: fields{ + provisioners: provisioners, + admins: admins, + }, + args: args{ + provName: "jwkProv", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"otherAdmin"}, // step not in policy + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok", + fields: fields{ + provisioners: provisioners, + admins: admins, + }, + args: args{ + provName: "jwkProv", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + provisioners: tt.fields.provisioners, + admins: tt.fields.admins, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + if err := a.checkProvisionerPolicy(tt.args.ctx, tt.args.provName, tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthority_RemoveAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + wantErr *PolicyError + }{ + { + name: "fail/adminDB.DeleteAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when deleting authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + err := a.RemoveAuthorityPolicy(tt.args.ctx) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + }) + } +} + +func TestAuthority_GetAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/adminDB.GetAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{}, nil + }, + }, + }, + want: &linkedca.Policy{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.GetAuthorityPolicy(tt.args.ctx) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_CreateAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + adm *linkedca.Admin + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/a.checkAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("error retrieving admins: force"), + }, + }, + { + name: "fail/adminDB.CreateAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + MockCreateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when creating authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, nil + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.CreateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.CreateAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_UpdateAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + adm *linkedca.Admin + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/a.checkAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("error retrieving admins: force"), + }, + }, + { + name: "fail/adminDB.UpdateAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when updating authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return nil + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.UpdateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.UpdateAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 913d0ace..9374d985 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -3,6 +3,8 @@ package provisioner import ( "context" "crypto/x509" + "fmt" + "net" "time" "github.com/pkg/errors" @@ -23,7 +25,8 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - ctl *Controller + + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -71,7 +74,7 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration { return p.ctl.Claimer.DefaultTLSCertDuration() } -// Init initializes and validates the fields of a JWK type. +// Init initializes and validates the fields of an ACME type. func (p *ACME) Init(config Config) (err error) { switch { case p.Type == "": @@ -80,15 +83,57 @@ func (p *ACME) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } +// ACMEIdentifierType encodes ACME Identifier types +type ACMEIdentifierType string + +const ( + // IP is the ACME ip identifier type + IP ACMEIdentifierType = "ip" + // DNS is the ACME dns identifier type + DNS ACMEIdentifierType = "dns" +) + +// ACMEIdentifier encodes ACME Order Identifiers +type ACMEIdentifier struct { + Type ACMEIdentifierType + Value string +} + +// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a +// certificate for an ACME Order Identifier. +func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error { + + x509Policy := p.ctl.getPolicy().getX509() + + // identifier is allowed if no policy is configured + if x509Policy == nil { + return nil + } + + // assuming only valid identifiers (IP or DNS) are provided + var err error + switch identifier.Type { + case IP: + err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value)) + case DNS: + err = x509Policy.IsDNSAllowed(identifier.Value) + default: + err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type) + } + + return err +} + // AuthorizeSign does not do any validation, because all validation is handled // in the ACME protocol. This method returns a list of modifiers / constraints // on the resulting certificate. func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return []SignOption{ + opts := []SignOption{ + p, // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), newForceCNOption(p.ForceCN), @@ -96,7 +141,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), - }, nil + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), + } + + return opts, nil } // AuthorizeRevoke is called just before the certificate is to be revoked by diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 49ac9468..33cbbc75 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -176,9 +176,10 @@ func TestACME_AuthorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { - assert.Len(t, 5, opts) + assert.Equals(t, 7, len(opts)) // number of SignOptions returned for _, o := range opts { switch v := o.(type) { + case *ACME: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeACME) assert.Equals(t, v.Name, tc.p.GetName()) @@ -192,6 +193,8 @@ func TestACME_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 5f79d7d0..a5b403a4 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -17,10 +17,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // awsIssuer is the string used as issuer in the generated tokens. @@ -49,22 +51,27 @@ const awsMetadataTokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds" // signature. // // The first certificate is used in: -// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2 -// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3 -// us-east-1, us-east-2, us-west-1, us-west-2 -// ca-central-1, sa-east-1 +// +// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2 +// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3 +// us-east-1, us-east-2, us-west-1, us-west-2 +// ca-central-1, sa-east-1 // // The second certificate is used in: -// eu-south-1 +// +// eu-south-1 // // The third certificate is used in: -// ap-east-1 +// +// ap-east-1 // // The fourth certificate is used in: -// af-south-1 +// +// af-south-1 // // The fifth certificate is used in: -// me-south-1 +// +// me-south-1 const awsCertificate = `-----BEGIN CERTIFICATE----- MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw @@ -421,7 +428,7 @@ func (p *AWS) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -467,6 +474,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return append(so, + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), @@ -475,6 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), ), nil } @@ -542,7 +551,7 @@ func (p *AWS) readURL(url string) ([]byte, error) { if err != nil { return nil, err } - return nil, fmt.Errorf("Request for metadata returned non-successful status code %d", + return nil, fmt.Errorf("request for metadata returned non-successful status code %d", resp.StatusCode) } @@ -575,7 +584,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) { } defer resp.Body.Close() if resp.StatusCode >= 400 { - return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) + return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode) } token, err := io.ReadAll(resp.Body) if err != nil { @@ -743,6 +752,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, signOptions = append(signOptions, templateOptions) return append(signOptions, + p, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. @@ -753,5 +763,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 1b7efa7c..d12d0626 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -642,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false}, - {"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false}, - {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false}, - {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false}, - {"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false}, + {"ok", p1, args{t1, "foo.local"}, 8, http.StatusOK, false}, + {"ok", p2, args{t2, "instance-id"}, 12, http.StatusOK, false}, + {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 12, http.StatusOK, false}, + {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 12, http.StatusOK, false}, + {"ok", p1, args{t4, "instance-id"}, 8, http.StatusOK, false}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, @@ -673,9 +673,10 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: - assert.Len(t, tt.wantLen, got) + assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { + case *AWS: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeAWS) @@ -698,6 +699,8 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } @@ -810,7 +813,6 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { - t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 93ef70e5..b6f7ec91 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -13,10 +13,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. @@ -219,7 +221,7 @@ func (p *Azure) Init(config Config) (err error) { return } - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -352,6 +354,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } return append(so, + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), @@ -359,6 +362,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), ), nil } @@ -414,6 +418,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio signOptions = append(signOptions, templateOptions) return append(signOptions, + p, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. @@ -424,6 +429,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 8002563c..3e745a5b 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -474,11 +474,11 @@ func TestAzure_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 10, http.StatusOK, false}, - {"ok", p1, args{t11}, 5, http.StatusOK, false}, - {"ok", p5, args{t5}, 5, http.StatusOK, false}, - {"ok", p7, args{t7}, 5, http.StatusOK, false}, + {"ok", p1, args{t1}, 7, http.StatusOK, false}, + {"ok", p2, args{t2}, 12, http.StatusOK, false}, + {"ok", p1, args{t11}, 7, http.StatusOK, false}, + {"ok", p5, args{t5}, 7, http.StatusOK, false}, + {"ok", p7, args{t7}, 7, http.StatusOK, false}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, {"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true}, @@ -502,9 +502,10 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: - assert.Len(t, tt.wantLen, got) + assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { + case *Azure: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeAzure) @@ -527,6 +528,8 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 2a3e2c61..96f19b37 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -24,8 +24,8 @@ type Claims struct { EnableSSHCA *bool `json:"enableSSHCA,omitempty"` // Renewal properties - DisableRenewal *bool `json:"disableRenewal,omitempty"` - AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"` + DisableRenewal *bool `json:"disableRenewal,omitempty"` + AllowRenewalAfterExpiry *bool `json:"allowRenewalAfterExpiry,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -44,22 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() - allowRenewAfterExpiry := c.AllowRenewAfterExpiry() + allowRenewalAfterExpiry := c.AllowRenewalAfterExpiry() enableSSHCA := c.IsSSHCAEnabled() return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, - MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, - DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, - MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, - MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, - DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, - EnableSSHCA: &enableSSHCA, - DisableRenewal: &disableRenewal, - AllowRenewAfterExpiry: &allowRenewAfterExpiry, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, + EnableSSHCA: &enableSSHCA, + DisableRenewal: &disableRenewal, + AllowRenewalAfterExpiry: &allowRenewalAfterExpiry, } } @@ -109,14 +109,14 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } -// AllowRenewAfterExpiry returns if the renewal flow is authorized if the +// AllowRenewalAfterExpiry returns if the renewal flow is authorized if the // certificate is expired. If the property is not set within the provisioner // then the global value from the authority configuration will be used. -func (c *Claimer) AllowRenewAfterExpiry() bool { - if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil { - return *c.global.AllowRenewAfterExpiry +func (c *Claimer) AllowRenewalAfterExpiry() bool { + if c.claims == nil || c.claims.AllowRenewalAfterExpiry == nil { + return *c.global.AllowRenewalAfterExpiry } - return *c.claims.AllowRenewAfterExpiry + return *c.claims.AllowRenewalAfterExpiry } // DefaultSSHCertDuration returns the default SSH certificate duration for the diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index a91ebaac..063ab50c 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -3,6 +3,7 @@ package provisioner import ( "context" "crypto/x509" + "net/http" "regexp" "strings" "time" @@ -21,14 +22,19 @@ type Controller struct { IdentityFunc GetIdentityFunc AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc + policy *policyEngine } // NewController initializes a new provisioner controller. -func NewController(p Interface, claims *Claims, config Config) (*Controller, error) { +func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) { claimer, err := NewClaimer(claims, config.Claims) if err != nil { return nil, err } + policy, err := newPolicyEngine(options) + if err != nil { + return nil, err + } return &Controller{ Interface: p, Audiences: &config.Audiences, @@ -36,6 +42,7 @@ func NewController(p Interface, claims *Claims, config Config) (*Controller, err IdentityFunc: config.GetIdentityFunc, AuthorizeRenewFunc: config.AuthorizeRenewFunc, AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, + policy: policy, }, nil } @@ -124,8 +131,10 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif if now.Before(cert.NotBefore) { return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) } - if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() { - return errs.Unauthorized("certificate has expired") + if now.After(cert.NotAfter) && !p.Claimer.AllowRenewalAfterExpiry() { + // return a custom 401 Unauthorized error with a clearer message for the client + // TODO(hs): these errors likely need to be refactored as a whole; HTTP status codes shouldn't be in this layer. + return errs.New(http.StatusUnauthorized, "The request lacked necessary authorization to be completed: certificate expired on %s", cert.NotAfter) } return nil @@ -144,7 +153,7 @@ func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Cert if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { return errs.Unauthorized("certificate is not yet valid") } - if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() { + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() { return errs.Unauthorized("certificate has expired") } @@ -192,3 +201,10 @@ func SanitizeSSHUserPrincipal(email string) string { } }, strings.ToLower(email)) } + +func (c *Controller) getPolicy() *policyEngine { + if c == nil { + return nil + } + return c.policy +} diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go index 9fb90e9d..37cbfd89 100644 --- a/authority/provisioner/controller_test.go +++ b/authority/provisioner/controller_test.go @@ -9,6 +9,8 @@ import ( "time" "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/authority/policy" ) var trueValue = true @@ -30,11 +32,40 @@ func mustDuration(t *testing.T, s string) *Duration { return d } +func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine { + t.Helper() + c, err := newPolicyEngine(options) + if err != nil { + t.Fatal(err) + } + return c +} + func TestNewController(t *testing.T) { + options := &Options{ + X509: &X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + SSH: &SSHOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@example.com"}, + }, + }, + }, + } type args struct { - p Interface - claims *Claims - config Config + p Interface + claims *Claims + config Config + options *Options } tests := []struct { name string @@ -45,7 +76,7 @@ func TestNewController(t *testing.T) { {"ok", args{&JWK{}, nil, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, - }}, &Controller{ + }, nil}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, nil, globalProvisionerClaims), @@ -55,24 +86,49 @@ func TestNewController(t *testing.T) { }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, - }}, &Controller{ + }, nil}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, &Claims{ DisableRenewal: &defaultDisableRenewal, }, globalProvisionerClaims), }, false}, + {"ok with claims and options", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }, options}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, globalProvisionerClaims), + policy: mustNewPolicyEngine(t, options), + }, false}, {"fail claimer", args{&JWK{}, &Claims{ MinTLSDur: mustDuration(t, "24h"), MaxTLSDur: mustDuration(t, "2h"), }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, + }, nil}, nil, true}, + {"fail options", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }, &Options{ + X509: &X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"**.local"}, + }, + }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) + got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options) if (err != nil) != tt.wantErr { t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) return @@ -160,13 +216,13 @@ func TestController_AuthorizeRenew(t *testing.T) { NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, - {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { return nil }}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, - {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), }}, false}, @@ -231,13 +287,13 @@ func TestController_AuthorizeSSHRenew(t *testing.T) { ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, - {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { return nil }}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, - {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), }}, false}, @@ -296,7 +352,7 @@ func TestDefaultAuthorizeRenew(t *testing.T) { }}, false}, {"ok renew after expiry", args{ctx, &Controller{ Interface: &JWK{}, - Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), @@ -354,7 +410,7 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) { }}, false}, {"ok renew after expiry", args{ctx, &Controller{ Interface: &JWK{}, - Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 6070b640..a116312d 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -14,10 +14,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // gcpCertsURL is the url that serves Google OAuth2 public keys. @@ -212,7 +214,7 @@ func (p *GCP) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -262,6 +264,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return append(so, + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), @@ -269,6 +272,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), ), nil } @@ -421,6 +425,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, signOptions = append(signOptions, templateOptions) return append(signOptions, + p, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. @@ -431,5 +436,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 4ac42bff..3c0bf92e 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -516,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 10, http.StatusOK, false}, - {"ok", p3, args{t3}, 5, http.StatusOK, false}, + {"ok", p1, args{t1}, 7, http.StatusOK, false}, + {"ok", p2, args{t2}, 12, http.StatusOK, false}, + {"ok", p3, args{t3}, 7, http.StatusOK, false}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, @@ -545,9 +545,10 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: - assert.Len(t, tt.wantLen, got) + assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { + case *GCP: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeGCP) @@ -570,6 +571,8 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index c014bec0..de592941 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -7,10 +7,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // jwtPayload extends jwt.Claims with step attributes. @@ -97,7 +99,7 @@ func (p *JWK) Init(config Config) (err error) { return errors.New("provisioner key cannot be empty") } - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -141,6 +143,7 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) + // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } @@ -170,6 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), @@ -179,6 +183,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -187,6 +192,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { + // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew) return p.ctl.AuthorizeRenew(ctx, cert) } @@ -251,6 +257,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, } return append(signOptions, + p, // Set the validity bounds if not set. &sshDefaultDuration{p.ctl.Claimer}, // Validate public key @@ -259,11 +266,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), ), nil } // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) + // TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 215d9c84..bd8b542b 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -297,9 +297,10 @@ func TestJWK_AuthorizeSign(t *testing.T) { } } else { if assert.NotNil(t, got) { - assert.Len(t, 7, got) + assert.Equals(t, 9, len(got)) for _, o := range got { switch v := o.(type) { + case *JWK: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeJWK) @@ -316,6 +317,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index 557d571a..28be0d5c 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -10,11 +10,13 @@ import ( "net/http" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // NOTE: There can be at most one kubernetes service account provisioner configured @@ -91,6 +93,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) { // Init initializes and validates the fields of a K8sSA type. func (p *K8sSA) Init(config Config) (err error) { + switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -137,7 +140,7 @@ func (p *K8sSA) Init(config Config) (err error) { p.kauthn = k8s.AuthenticationV1() */ - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -231,6 +234,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), @@ -238,6 +242,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -270,6 +275,7 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio signOptions := []SignOption{templateOptions} return append(signOptions, + p, // Require type, key-id and principals in the SignSSHOptions. &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, // Set the validity bounds if not set. @@ -280,6 +286,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), ), nil } diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index b1aa3b55..2458babb 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -280,9 +280,9 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 for _, o := range opts { switch v := o.(type) { + case *K8sSA: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeK8sSA) @@ -295,12 +295,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 5) + assert.Equals(t, 7, len(opts)) } } } @@ -367,9 +368,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 + assert.Len(t, 8, opts) for _, o := range opts { switch v := o.(type) { + case Interface: case sshCertificateOptionsFunc: case *sshCertOptionsRequireValidator: assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) @@ -379,12 +381,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertDefaultValidator: case *sshDefaultDuration: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) + case *sshNamePolicyValidator: + assert.Equals(t, nil, v.userPolicyEngine) + assert.Equals(t, nil, v.hostPolicyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 6) } } } diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index f5cd5221..01dda2ed 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method { m, _ := ctx.Value(methodKey{}).(Method) return m } + +type tokenKey struct{} + +// NewContextWithToken creates a new context with the given token. +func NewContextWithToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, tokenKey{}, token) +} + +// TokenFromContext returns the token stored in the given context. +func TokenFromContext(ctx context.Context) (string, bool) { + token, ok := ctx.Value(tokenKey{}).(string) + return token, ok +} diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 1a6eee3e..cde5857c 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -10,12 +10,14 @@ import ( "github.com/pkg/errors" nebula "github.com/slackhq/nebula/cert" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x25519" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/errs" ) const ( @@ -61,7 +63,7 @@ func (p *Nebula) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -144,6 +146,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeNebula, p.Name, ""), @@ -160,6 +163,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, }, defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -246,6 +250,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti } return append(signOptions, + p, templateOptions, // Checks the validity bounds, and set the validity if has not been set. &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter}, @@ -255,6 +260,8 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/noop.go b/authority/provisioner/noop.go index 1709fbca..9ccd0c8c 100644 --- a/authority/provisioner/noop.go +++ b/authority/provisioner/noop.go @@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error { } func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return []SignOption{}, nil + return []SignOption{p}, nil } func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { @@ -50,7 +50,7 @@ func (p *noop) AuthorizeRevoke(ctx context.Context, token string) error { } func (p *noop) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - return []SignOption{}, nil + return []SignOption{p}, nil } func (p *noop) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { diff --git a/authority/provisioner/noop_test.go b/authority/provisioner/noop_test.go index 19e4d235..b10d1d29 100644 --- a/authority/provisioner/noop_test.go +++ b/authority/provisioner/noop_test.go @@ -24,6 +24,6 @@ func Test_noop(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) sigOptions, err := p.AuthorizeSign(ctx, "foo") - assert.Equals(t, []SignOption{}, sigOptions) + assert.Equals(t, []SignOption{&p}, sigOptions) assert.Equals(t, nil, err) } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 1fc9bb4b..e64d98d9 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -12,10 +12,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // openIDConfiguration contains the necessary properties in the @@ -195,7 +197,7 @@ func (o *OIDC) Init(config Config) (err error) { return err } - o.ctl, err = NewController(o, o.Claims, config) + o.ctl, err = NewController(o, o.Claims, config, o.Options) return } @@ -345,6 +347,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e } return []SignOption{ + o, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), @@ -352,6 +355,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators defaultPublicKeyValidator{}, newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(o.ctl.getPolicy().getX509()), }, nil } @@ -430,6 +434,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption } return append(signOptions, + o, // Set the validity bounds if not set. &sshDefaultDuration{o.ctl.Claimer}, // Validate public key @@ -438,6 +443,8 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(o.ctl.getPolicy().getSSHHost(), o.ctl.getPolicy().getSSHUser()), ), nil } diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 18b568a7..3d039496 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -323,9 +323,10 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { - assert.Len(t, 5, got) + assert.Equals(t, 7, len(got)) for _, o := range got { switch v := o.(type) { + case *OIDC: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeOIDC) @@ -340,6 +341,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go index f86c4863..f5c919b4 100644 --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -5,8 +5,11 @@ import ( "strings" "github.com/pkg/errors" + "go.step.sm/crypto/jose" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/policy" ) // CertificateOptions is an interface that returns a list of options passed when @@ -56,6 +59,16 @@ type X509Options struct { // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` + + // AllowedNames contains the SANs the provisioner is authorized to sign + AllowedNames *policy.X509NameOptions `json:"-"` + + // DeniedNames contains the SANs the provisioner is not authorized to sign + DeniedNames *policy.X509NameOptions `json:"-"` + + // AllowWildcardNames indicates if literal wildcard names + // like *.example.com are allowed. Defaults to false. + AllowWildcardNames bool `json:"-"` } // HasTemplate returns true if a template is defined in the provisioner options. @@ -63,6 +76,31 @@ func (o *X509Options) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } +// GetAllowedNameOptions returns the AllowedNames, which models the +// SANs that a provisioner is authorized to sign x509 certificates for. +func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedNames, which models the +// SANs that a provisioner is NOT authorized to sign x509 certificates for. +func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +func (o *X509Options) AreWildcardNamesAllowed() bool { + if o == nil { + return true + } + return o.AllowWildcardNames +} + // TemplateOptions generates a CertificateOptions with the template and data // defined in the ProvisionerOptions, the provisioner generated data, and the // user data provided in the request. If no template has been provided, diff --git a/authority/provisioner/options_test.go b/authority/provisioner/options_test.go index 8f411aca..0bcf9ec3 100644 --- a/authority/provisioner/options_test.go +++ b/authority/provisioner/options_test.go @@ -287,3 +287,38 @@ func Test_unsafeParseSigned(t *testing.T) { }) } } + +func TestX509Options_IsWildcardLiteralAllowed(t *testing.T) { + tests := []struct { + name string + options *X509Options + want bool + }{ + { + name: "nil-options", + options: nil, + want: true, + }, + { + name: "set-true", + options: &X509Options{ + AllowWildcardNames: true, + }, + want: true, + }, + { + name: "set-false", + options: &X509Options{ + AllowWildcardNames: false, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.options.AreWildcardNamesAllowed(); got != tt.want { + t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/policy.go b/authority/provisioner/policy.go new file mode 100644 index 00000000..95ef4163 --- /dev/null +++ b/authority/provisioner/policy.go @@ -0,0 +1,65 @@ +package provisioner + +import "github.com/smallstep/certificates/authority/policy" + +type policyEngine struct { + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy +} + +func newPolicyEngine(options *Options) (*policyEngine, error) { + + if options == nil { + return nil, nil + } + + var ( + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy + err error + ) + + // Initialize the x509 allow/deny policy engine + if x509Policy, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil { + return nil, err + } + + // Initialize the SSH allow/deny policy engine for host certificates + if sshHostPolicy, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + // Initialize the SSH allow/deny policy engine for user certificates + if sshUserPolicy, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + return &policyEngine{ + x509Policy: x509Policy, + sshHostPolicy: sshHostPolicy, + sshUserPolicy: sshUserPolicy, + }, nil +} + +func (p *policyEngine) getX509() policy.X509Policy { + if p == nil { + return nil + } + return p.x509Policy +} + +func (p *policyEngine) getSSHHost() policy.HostPolicy { + if p == nil { + return nil + } + return p.sshHostPolicy +} + +func (p *policyEngine) getSSHUser() policy.UserPolicy { + if p == nil { + return nil + } + return p.sshUserPolicy +} diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 7438ea17..0d5cd41a 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/pkg/errors" - "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" ) @@ -212,8 +211,6 @@ type Config struct { Claims Claims // Audiences are the audiences used in the default provisioner, (JWK). Audiences Audiences - // DB is the interface to the authority DB client. - DB db.AuthDB // SSHKeys are the root SSH public keys SSHKeys *SSHKeys // GetIdentityFunc is a function that returns an identity that will be diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index f4cffd78..c49c993e 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -28,13 +28,12 @@ type SCEP struct { // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - secretChallengePassword string - encryptionAlgorithm int - ctl *Controller + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + ctl *Controller + secretChallengePassword string + encryptionAlgorithm int } // GetID returns the provisioner unique identifier. @@ -84,7 +83,6 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration { // Init initializes and validates the fields of a SCEP type. func (s *SCEP) Init(config Config) (err error) { - switch { case s.Type == "": return errors.New("provisioner type cannot be empty") @@ -112,7 +110,7 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? - s.ctl, err = NewController(s, s.Claims, config) + s.ctl, err = NewController(s, s.Claims, config, s.Options) return } @@ -121,6 +119,7 @@ func (s *SCEP) Init(config Config) (err error) { // on the resulting certificate. func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { return []SignOption{ + s, // modifiers / withOptions newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newForceCNOption(s.ForceCN), @@ -128,6 +127,7 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(s.ctl.getPolicy().getX509()), }, nil } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 80dfc66e..2eefd331 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -13,9 +13,11 @@ import ( "reflect" "time" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/errs" ) // DefaultCertValidity is the default validity for a certificate if none is specified. @@ -402,6 +404,32 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return nil } +// x509NamePolicyValidator validates that the certificate (to be signed) +// contains only allowed SANs. +type x509NamePolicyValidator struct { + policyEngine policy.X509Policy +} + +// newX509NamePolicyValidator return a new SANs allow/deny validator. +func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator { + return &x509NamePolicyValidator{ + policyEngine: engine, + } +} + +// Valid validates that the certificate (to be signed) contains only allowed SANs. +func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error { + if v.policyEngine == nil { + return nil + } + return v.policyEngine.IsX509CertificateAllowed(cert) +} + +// var ( +// stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} +// stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) +// ) + // type stepProvisionerASN1 struct { // Type int // Name []byte diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index a2ca78b1..70dffba2 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -4,11 +4,13 @@ import ( "crypto/rsa" "encoding/binary" "encoding/json" + "fmt" "math/big" "strings" "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" @@ -444,6 +446,53 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti } } +// sshNamePolicyValidator validates that the certificate (to be signed) +// contains only allowed principals. +type sshNamePolicyValidator struct { + hostPolicyEngine policy.HostPolicy + userPolicyEngine policy.UserPolicy +} + +// newSSHNamePolicyValidator return a new SSH allow/deny validator. +func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator { + return &sshNamePolicyValidator{ + hostPolicyEngine: host, + userPolicyEngine: user, + } +} + +// Valid validates that the certificate (to be signed) contains only allowed principals. +func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { + if v.hostPolicyEngine == nil && v.userPolicyEngine == nil { + // no policy configured at all; allow anything + return nil + } + + // Check the policy type to execute based on type of the certificate. + // We don't allow user certs if only a host policy engine is configured and + // the same for host certs: if only a user policy engine is configured, host + // certs are denied. When both policy engines are configured, the type of + // cert determines which policy engine is used. + switch cert.CertType { + case ssh.HostCert: + // when no host policy engine is configured, but a user policy engine is + // configured, the host certificate is denied. + if v.hostPolicyEngine == nil && v.userPolicyEngine != nil { + return errors.New("SSH host certificate not authorized") + } + return v.hostPolicyEngine.IsSSHCertificateAllowed(cert) + case ssh.UserCert: + // when no user policy engine is configured, but a host policy engine is + // configured, the user certificate is denied. + if v.userPolicyEngine == nil && v.hostPolicyEngine != nil { + return errors.New("SSH user certificate not authorized") + } + return v.userPolicyEngine.IsSSHCertificateAllowed(cert) + default: + return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen + } +} + // sshCertTypeUInt32 func sshCertTypeUInt32(ct string) uint32 { switch ct { diff --git a/authority/provisioner/ssh_options.go b/authority/provisioner/ssh_options.go index 7ee236d1..93633a21 100644 --- a/authority/provisioner/ssh_options.go +++ b/authority/provisioner/ssh_options.go @@ -6,6 +6,8 @@ import ( "github.com/pkg/errors" "go.step.sm/crypto/sshutil" + + "github.com/smallstep/certificates/authority/policy" ) // SSHCertificateOptions is an interface that returns a list of options passed when @@ -33,6 +35,60 @@ type SSHOptions struct { // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` + + // User contains SSH user certificate options. + User *policy.SSHUserCertificateOptions `json:"-"` + + // Host contains SSH host certificate options. + Host *policy.SSHHostCertificateOptions `json:"-"` +} + +// GetAllowedUserNameOptions returns the SSHNameOptions that are +// allowed when SSH User certificates are requested. +func (o *SSHOptions) GetAllowedUserNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.AllowedNames +} + +// GetDeniedUserNameOptions returns the SSHNameOptions that are +// denied when SSH user certificates are requested. +func (o *SSHOptions) GetDeniedUserNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.DeniedNames +} + +// GetAllowedHostNameOptions returns the SSHNameOptions that are +// allowed when SSH host certificates are requested. +func (o *SSHOptions) GetAllowedHostNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.AllowedNames +} + +// GetDeniedHostNameOptions returns the SSHNameOptions that are +// denied when SSH host certificates are requested. +func (o *SSHOptions) GetDeniedHostNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.DeniedNames } // HasTemplate returns true if a template is defined in the provisioner options. diff --git a/authority/provisioner/ssh_test.go b/authority/provisioner/ssh_test.go index c530cd3c..90271443 100644 --- a/authority/provisioner/ssh_test.go +++ b/authority/provisioner/ssh_test.go @@ -53,6 +53,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []Si for _, op := range signOpts { switch o := op.(type) { + case Interface: // add options to NewCertificate case SSHCertificateOptions: certOptions = append(certOptions, o.Options(opts)...) diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 9de0fca2..c0246729 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -8,9 +8,11 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/jose" + + "github.com/smallstep/certificates/errs" ) // sshPOPPayload extends jwt.Claims with step attributes. @@ -95,7 +97,7 @@ func (p *SSHPOP) Init(config Config) (err error) { p.sshPubKeys = config.SSHKeys config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, nil) return } @@ -220,6 +222,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ + p, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 13294866..1e026883 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -459,9 +459,10 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 3, opts) + assert.Len(t, 4, opts) for _, o := range opts { switch v := o.(type) { + case Interface: case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshCertValidityValidator: diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index c55c58d2..0a1d176c 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -24,22 +24,22 @@ import ( ) var ( - defaultDisableRenewal = false - defaultAllowRenewAfterExpiry = false - defaultEnableSSHCA = true - globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, - DisableRenewal: &defaultDisableRenewal, - AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry, + defaultDisableRenewal = false + defaultAllowRenewalAfterExpiry = false + defaultEnableSSHCA = true + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, + DisableRenewal: &defaultDisableRenewal, + AllowRenewalAfterExpiry: &defaultAllowRenewalAfterExpiry, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, @@ -184,7 +184,7 @@ func generateJWK() (*JWK, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -219,7 +219,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -256,7 +256,7 @@ func generateSSHPOP() (*SSHPOP, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -305,7 +305,7 @@ M46l92gdOozT } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -343,7 +343,7 @@ func generateOIDC() (*OIDC, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -373,7 +373,7 @@ func generateGCP() (*GCP, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("gcp/" + name), - }) + }, nil) return p, err } @@ -411,7 +411,7 @@ func generateAWS() (*AWS, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("aws/" + name), - }) + }, nil) return p, err } @@ -518,7 +518,7 @@ func generateAWSV1Only() (*AWS, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("aws/" + name), - }) + }, nil) return p, err } @@ -608,7 +608,7 @@ func generateAzure() (*Azure, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 51b5d8fd..b9ae24c5 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -8,10 +8,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // x5cPayload extends jwt.Claims with step attributes. @@ -121,7 +123,7 @@ func (p *X5C) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -220,6 +222,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeX5C, p.Name, ""), @@ -232,6 +235,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -308,6 +312,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, } return append(signOptions, + p, // Checks the validity bounds, and set the validity if has not been set. &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter}, // Validate public key. @@ -316,5 +321,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), ), nil } diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 22dd8541..3bcf30d1 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -468,9 +468,10 @@ func TestX5C_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - assert.Equals(t, len(opts), 7) + assert.Equals(t, 9, len(opts)) for _, o := range opts { switch v := o.(type) { + case *X5C: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeX5C) @@ -479,7 +480,6 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) @@ -491,6 +491,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } @@ -767,6 +769,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { nw := now() for _, o := range opts { switch v := o.(type) { + case Interface: case sshCertOptionsValidator: tc.claims.Step.SSH.ValidAfter.t = time.Time{} tc.claims.Step.SSH.ValidBefore.t = time.Time{} @@ -787,6 +790,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) + case *sshNamePolicyValidator: + assert.Equals(t, nil, v.userPolicyEngine) + assert.Equals(t, nil, v.hostPolicyEngine) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) @@ -794,9 +800,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tot++ } if len(tc.claims.Step.SSH.CertType) > 0 { - assert.Equals(t, tot, 9) + assert.Equals(t, tot, 11) } else { - assert.Equals(t, tot, 7) + assert.Equals(t, tot, 9) } } } diff --git a/authority/provisioners.go b/authority/provisioners.go index a6ac5aa8..1fd34ef0 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -10,15 +10,19 @@ import ( "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/config" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/errs" + "gopkg.in/square/go-jose.v2/jwt" + "go.step.sm/cli-utils/step" "go.step.sm/cli-utils/ui" "go.step.sm/crypto/jose" "go.step.sm/linkedca" - "gopkg.in/square/go-jose.v2/jwt" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) // GetEncryptedKey returns the JWE key corresponding to the given kid argument. @@ -46,13 +50,43 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() + if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil { + return p, nil + } + return a.unsafeLoadProvisionerFromExtension(crt) +} + +func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) { p, ok := a.provisioners.LoadByCertificate(crt) - if !ok { + if !ok || p.GetType() == 0 { return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") } return p, nil } +func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) { + // certificateDataGetter is an interface that can be used to retrieve the + // provisioner from a db or a linked ca. + type certificateDataGetter interface { + GetCertificateData(string) (*db.CertificateData, error) + } + + var err error + var data *db.CertificateData + + if cdg, ok := a.adminDB.(certificateDataGetter); ok { + data, err = cdg.GetCertificateData(crt.SerialNumber.String()) + } else if cdg, ok := a.db.(certificateDataGetter); ok { + data, err = cdg.GetCertificateData(crt.SerialNumber.String()) + } + if err == nil && data != nil && data.Provisioner != nil { + if p, ok := a.provisioners.Load(data.Provisioner.ID); ok { + return p, nil + } + } + return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") +} + // LoadProvisionerByToken returns an interface to the provisioner that // provisioned the token. func (a *Authority) LoadProvisionerByToken(token *jwt.JSONWebToken, claims *jwt.Claims) (provisioner.Interface, error) { @@ -103,7 +137,6 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner. return provisioner.Config{ Claims: claimer.Claims(), Audiences: a.config.GetAudiences(), - DB: a.db, SSHKeys: &provisioner.SSHKeys{ UserKeys: sshKeys.UserKeys, HostKeys: sshKeys.HostKeys, @@ -115,7 +148,7 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner. } -// StoreProvisioner stores an provisioner.Interface to the authority. +// StoreProvisioner stores a provisioner to the authority. func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { a.adminMutex.Lock() defer a.adminMutex.Unlock() @@ -140,6 +173,10 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi return admin.WrapErrorISE(err, "error generating provisioner config") } + if err := a.checkProvisionerPolicy(ctx, prov.Name, prov.Policy); err != nil { + return err + } + if err := certProv.Init(provisionerConfig); err != nil { return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name) } @@ -161,7 +198,7 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi } if err := a.provisioners.Store(certProv); err != nil { - if err := a.reloadAdminResources(ctx); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner store") } return admin.WrapErrorISE(err, "error storing provisioner in authority cache") @@ -185,6 +222,10 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio return admin.WrapErrorISE(err, "error generating provisioner config") } + if err := a.checkProvisionerPolicy(ctx, nu.Name, nu.Policy); err != nil { + return err + } + if err := certProv.Init(provisionerConfig); err != nil { return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name) } @@ -193,7 +234,7 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio return admin.WrapErrorISE(err, "error updating provisioner '%s' in authority cache", nu.Name) } if err := a.adminDB.UpdateProvisioner(ctx, nu); err != nil { - if err := a.reloadAdminResources(ctx); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner update") } return admin.WrapErrorISE(err, "error updating provisioner '%s'", nu.Name) @@ -213,31 +254,33 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error { } provName, provID := p.GetName(), p.GetID() - // Validate - // - Check that there will be SUPER_ADMINs that remain after we - // remove this provisioner. - if a.admins.SuperCount() == a.admins.SuperCountByProvisioner(provName) { - return admin.NewError(admin.ErrorBadRequestType, - "cannot remove provisioner %s because no super admins will remain", provName) - } + if a.IsAdminAPIEnabled() { + // Validate + // - Check that there will be SUPER_ADMINs that remain after we + // remove this provisioner. + if a.IsAdminAPIEnabled() && a.admins.SuperCount() == a.admins.SuperCountByProvisioner(provName) { + return admin.NewError(admin.ErrorBadRequestType, + "cannot remove provisioner %s because no super admins will remain", provName) + } - // Delete all admins associated with the provisioner. - admins, ok := a.admins.LoadByProvisioner(provName) - if ok { - for _, adm := range admins { - if err := a.removeAdmin(ctx, adm.Id); err != nil { - return admin.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, provName) + // Delete all admins associated with the provisioner. + admins, ok := a.admins.LoadByProvisioner(provName) + if ok { + for _, adm := range admins { + if err := a.removeAdmin(ctx, adm.Id); err != nil { + return admin.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, provName) + } } } } // Remove provisioner from authority caches. if err := a.provisioners.Remove(provID); err != nil { - return admin.WrapErrorISE(err, "error removing admin from authority cache") + return admin.WrapErrorISE(err, "error removing provisioner from authority cache") } // Remove provisioner from database. if err := a.adminDB.DeleteProvisioner(ctx, provID); err != nil { - if err := a.reloadAdminResources(ctx); err != nil { + if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner remove") } return admin.WrapErrorISE(err, "error deleting provisioner %s", provName) @@ -247,7 +290,7 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error { // CreateFirstProvisioner creates and stores the first provisioner when using // admin database provisioner storage. -func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { +func CreateFirstProvisioner(ctx context.Context, adminDB admin.DB, password string) (*linkedca.Provisioner, error) { if password == "" { pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") if err != nil { @@ -290,7 +333,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) ( }, }, } - if err := db.CreateProvisioner(ctx, p); err != nil { + if err := adminDB.CreateProvisioner(ctx, p); err != nil { return nil, admin.WrapErrorISE(err, "error creating provisioner") } return p, nil @@ -397,6 +440,60 @@ func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options { ops.SSH.Template = string(p.SshTemplate.Template) ops.SSH.TemplateData = p.SshTemplate.Data } + if pol := p.GetPolicy(); pol != nil { + if x := pol.GetX509(); x != nil { + if allow := x.GetAllow(); allow != nil { + ops.X509.AllowedNames = &policy.X509NameOptions{ + DNSDomains: allow.Dns, + IPRanges: allow.Ips, + EmailAddresses: allow.Emails, + URIDomains: allow.Uris, + } + } + if deny := x.GetDeny(); deny != nil { + ops.X509.DeniedNames = &policy.X509NameOptions{ + DNSDomains: deny.Dns, + IPRanges: deny.Ips, + EmailAddresses: deny.Emails, + URIDomains: deny.Uris, + } + } + } + if ssh := pol.GetSsh(); ssh != nil { + if host := ssh.GetHost(); host != nil { + ops.SSH.Host = &policy.SSHHostCertificateOptions{} + if allow := host.GetAllow(); allow != nil { + ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{ + DNSDomains: allow.Dns, + IPRanges: allow.Ips, + Principals: allow.Principals, + } + } + if deny := host.GetDeny(); deny != nil { + ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{ + DNSDomains: deny.Dns, + IPRanges: deny.Ips, + Principals: deny.Principals, + } + } + } + if user := ssh.GetUser(); user != nil { + ops.SSH.User = &policy.SSHUserCertificateOptions{} + if allow := user.GetAllow(); allow != nil { + ops.SSH.User.AllowedNames = &policy.SSHNameOptions{ + EmailAddresses: allow.Emails, + Principals: allow.Principals, + } + } + if deny := user.GetDeny(); deny != nil { + ops.SSH.User.DeniedNames = &policy.SSHNameOptions{ + EmailAddresses: deny.Emails, + Principals: deny.Principals, + } + } + } + } + } return ops } @@ -437,8 +534,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { } pc := &provisioner.Claims{ - DisableRenewal: &c.DisableRenewal, - AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry, + DisableRenewal: &c.DisableRenewal, + AllowRenewalAfterExpiry: &c.AllowRenewalAfterExpiry, } var err error @@ -476,18 +573,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims { } disableRenewal := config.DefaultDisableRenewal - allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry + allowRenewalAfterExpiry := config.DefaultAllowRenewalAfterExpiry if c.DisableRenewal != nil { disableRenewal = *c.DisableRenewal } - if c.AllowRenewAfterExpiry != nil { - allowRenewAfterExpiry = *c.AllowRenewAfterExpiry + if c.AllowRenewalAfterExpiry != nil { + allowRenewalAfterExpiry = *c.AllowRenewalAfterExpiry } lc := &linkedca.Claims{ - DisableRenewal: disableRenewal, - AllowRenewAfterExpiry: allowRenewAfterExpiry, + DisableRenewal: disableRenewal, + AllowRenewalAfterExpiry: allowRenewalAfterExpiry, } if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 81dc38bf..56cd16b1 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,13 +1,21 @@ package authority import ( + "context" + "crypto/x509" "errors" "net/http" + "reflect" "testing" + "time" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" ) func TestGetEncryptedKey(t *testing.T) { @@ -67,6 +75,15 @@ func TestGetEncryptedKey(t *testing.T) { } } +type mockAdminDB struct { + admin.MockDB + MGetCertificateData func(string) (*db.CertificateData, error) +} + +func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) { + return c.MGetCertificateData(sn) +} + func TestGetProvisioners(t *testing.T) { type gp struct { a *Authority @@ -104,3 +121,133 @@ func TestGetProvisioners(t *testing.T) { }) } } + +func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { + _, priv, err := keyutil.GenerateDefaultKeyPair() + assert.FatalError(t, err) + csr := getCSR(t, priv) + + sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate { + key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) + assert.FatalError(t, err) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + opts, err := a.Authorize(ctx, token) + assert.FatalError(t, err) + opts = append(opts, extraOpts...) + certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + assert.FatalError(t, err) + return certs[0] + } + getProvisioner := func(a *Authority, name string) provisioner.Interface { + p, ok := a.provisioners.LoadByName(name) + if !ok { + t.Fatalf("provisioner %s does not exists", name) + } + return p + } + removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + for i, ext := range cert.ExtraExtensions { + if ext.Id.Equal(provisioner.StepOIDProvisioner) { + cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...) + break + } + } + return nil + }) + + a0 := testAuthority(t) + + a1 := testAuthority(t) + a1.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + p, err := a1.LoadProvisionerByName("dev") + if err != nil { + t.Fatal(err) + } + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: p.GetID(), + Name: p.GetName(), + Type: p.GetType().String(), + }, + }, nil + }, + } + + a2 := testAuthority(t) + a2.adminDB = &mockAdminDB{ + MGetCertificateData: (func(s string) (*db.CertificateData, error) { + p, err := a2.LoadProvisionerByName("dev") + if err != nil { + t.Fatal(err) + } + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: p.GetID(), + Name: p.GetName(), + Type: p.GetType().String(), + }, + }, nil + }), + } + + a3 := testAuthority(t) + a3.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: "foo", Name: "foo", Type: "foo", + }, + }, nil + }, + } + + a4 := testAuthority(t) + a4.adminDB = &mockAdminDB{ + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: "foo", Name: "foo", Type: "foo", + }, + }, nil + }, + } + + type args struct { + crt *x509.Certificate + } + tests := []struct { + name string + authority *Authority + args args + want provisioner.Interface + wantErr bool + }{ + {"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false}, + {"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false}, + {"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false}, + {"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true}, + {"fail from db", a3, args{sign(a3, removeExtension)}, nil, true}, + {"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/ssh.go b/authority/ssh.go index 4a67b28c..1fd7f2e8 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -5,18 +5,23 @@ import ( "crypto/rand" "crypto/x509" "encoding/binary" + "errors" + "fmt" "net/http" "strings" "time" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/randutil" + "go.step.sm/crypto/sshutil" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" + policy "github.com/smallstep/certificates/policy" "github.com/smallstep/certificates/templates" - "go.step.sm/crypto/randutil" - "go.step.sm/crypto/sshutil" - "golang.org/x/crypto/ssh" ) const ( @@ -156,8 +161,13 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Set backdate with the configured value opts.Backdate = a.config.AuthorityConfig.Backdate.Duration + var prov provisioner.Interface for _, op := range signOpts { switch o := op.(type) { + // Capture current provisioner + case provisioner.Interface: + prov = o + // add options to NewCertificate case provisioner.SSHCertificateOptions: certOptions = append(certOptions, o.Options(opts)...) @@ -241,6 +251,23 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType) } + // Check if authority is allowed to sign the certificate + if err := a.isAllowedToSignSSHCertificate(certTpl); err != nil { + var pe *policy.NamePolicyError + if errors.As(err, &pe) && pe.Reason == policy.NotAllowed { + return nil, &errs.Error{ + // NOTE: custom forbidden error, so that denied name is sent to client + // as well as shown in the logs. + Status: http.StatusForbidden, + Err: fmt.Errorf("authority not allowed to sign: %w", err), + Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()), + } + } + return nil, errs.InternalServerErr(err, + errs.WithMessage("authority.SignSSH: error creating ssh certificate"), + ) + } + // Sign certificate. cert, err := sshutil.CreateCertificate(certTpl, signer) if err != nil { @@ -254,13 +281,18 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi } } - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeSSHCertificate(prov, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db") } return cert, nil } +// isAllowedToSignSSHCertificate checks if the Authority is allowed to sign the SSH certificate. +func (a *Authority) isAllowedToSignSSHCertificate(cert *ssh.Certificate) error { + return a.policyEngine.IsSSHCertificateAllowed(cert) +} + // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { @@ -271,6 +303,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, err } + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + backdate := a.config.AuthorityConfig.Backdate.Duration duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second now := time.Now() @@ -313,7 +351,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") } @@ -324,8 +362,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var validators []provisioner.SSHCertValidator + var prov provisioner.Interface for _, op := range signOpts { switch o := op.(type) { + // Capture current provisioner + case provisioner.Interface: + prov = o // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) @@ -392,21 +434,59 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } } - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") } return cert, nil } -func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error { +func (a *Authority) storeSSHCertificate(prov provisioner.Interface, cert *ssh.Certificate) error { type sshCertificateStorer interface { - StoreSSHCertificate(crt *ssh.Certificate) error + StoreSSHCertificate(provisioner.Interface, *ssh.Certificate) error + } + + // Store certificate in admindb or linkedca + switch s := a.adminDB.(type) { + case sshCertificateStorer: + return s.StoreSSHCertificate(prov, cert) + case db.CertificateStorer: + return s.StoreSSHCertificate(cert) + } + + // Store certificate in localdb + switch s := a.db.(type) { + case sshCertificateStorer: + return s.StoreSSHCertificate(prov, cert) + case db.CertificateStorer: + return s.StoreSSHCertificate(cert) + default: + return nil + } +} + +func (a *Authority) storeRenewedSSHCertificate(prov provisioner.Interface, parent, cert *ssh.Certificate) error { + type sshRenewerCertificateStorer interface { + StoreRenewedSSHCertificate(p provisioner.Interface, parent, cert *ssh.Certificate) error } - if s, ok := a.adminDB.(sshCertificateStorer); ok { + + // Store certificate in admindb or linkedca + switch s := a.adminDB.(type) { + case sshRenewerCertificateStorer: + return s.StoreRenewedSSHCertificate(prov, parent, cert) + case db.CertificateStorer: return s.StoreSSHCertificate(cert) } - return a.db.StoreSSHCertificate(cert) + + // Store certificate in localdb + switch s := a.db.(type) { + case sshRenewerCertificateStorer: + return s.StoreRenewedSSHCertificate(prov, parent, cert) + case db.CertificateStorer: + return s.StoreSSHCertificate(cert) + default: + return nil + } } // IsValidForAddUser checks if a user provisioner certificate can be issued to @@ -452,6 +532,12 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number") } + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + signer := a.sshCAUserCertSignKey principal := subject.ValidPrincipals[0] addUserPrincipal := a.getAddUserPrincipal() @@ -484,7 +570,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje } cert.Signature = sig - if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, subject, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") } diff --git a/authority/ssh_test.go b/authority/ssh_test.go index ce840fe1..4fd7eaa0 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -20,6 +20,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/templates" @@ -159,6 +160,14 @@ func TestAuthority_SignSSH(t *testing.T) { assert.FatalError(t, err) hostTemplateWithHosts, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"foo.test.com", "bar.test.com"})) assert.FatalError(t, err) + userTemplateWithRoot, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"root"})) + assert.FatalError(t, err) + hostTemplateWithExampleDotCom, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"example.com"})) + assert.FatalError(t, err) + badUserTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"127.0.0.1"})) + assert.FatalError(t, err) + badHostTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"host...local"})) + assert.FatalError(t, err) userCustomTemplate, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{Template: `{ "type": "{{ .Type }}", @@ -182,11 +191,36 @@ func TestAuthority_SignSSH(t *testing.T) { }, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) + userPolicyOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"user"}, + }, + }, + }, + } + userPolicy, err := policy.New(userPolicyOptions) + assert.FatalError(t, err) + + hostPolicyOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.test.com"}, + }, + }, + }, + } + hostPolicy, err := policy.New(hostPolicyOptions) + assert.FatalError(t, err) + now := time.Now() type fields struct { sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer + policyEngine *policy.Engine } type args struct { key ssh.PublicKey @@ -206,39 +240,48 @@ func TestAuthority_SignSSH(t *testing.T) { want want wantErr bool }{ - {"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, - {"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, - {"ok-user-only", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, - {"ok-host-only", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, - {"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, - {"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, - {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, - {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, - {"ok-opts-valid-after", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, - {"ok-opts-valid-before", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, - {"ok-cert-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-cert-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-opts-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-opts-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-custom-template", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false}, - {"fail-opts-type", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true}, - {"fail-cert-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true}, - {"fail-cert-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true}, - {"fail-opts-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, - {"fail-opts-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, - {"fail-bad-sign-options", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true}, - {"fail-no-user-key", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true}, - {"fail-no-host-key", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true}, - {"fail-bad-type", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true}, - {"fail-custom-template", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true}, - {"fail-custom-template-syntax-error-file", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true}, - {"fail-custom-template-syntax-value-file", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true}, + {"ok-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, + {"ok-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, + {"ok-user-only", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, + {"ok-host-only", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, + {"ok-opts-type-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-type-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, + {"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, + {"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, + {"ok-opts-valid-after", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, + {"ok-opts-valid-before", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, + {"ok-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false}, + {"ok-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, + {"ok-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, + {"fail-opts-type", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true}, + {"fail-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true}, + {"fail-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true}, + {"fail-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, + {"fail-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, + {"fail-bad-sign-options", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true}, + {"fail-no-user-key", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true}, + {"fail-no-host-key", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true}, + {"fail-bad-type", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true}, + {"fail-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true}, + {"fail-custom-template-syntax-error-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true}, + {"fail-custom-template-syntax-value-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true}, + {"fail-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"root"}}, []provisioner.SignOption{userTemplateWithRoot}}, want{}, true}, + {"fail-user-policy-with-host-cert", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true}, + {"fail-user-policy-with-bad-user", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{badUserTemplate}}, want{}, true}, + {"fail-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true}, + {"fail-host-policy-with-user-cert", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{}, true}, + {"fail-host-policy-with-bad-host", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{badHostTemplate}}, want{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey + a.policyEngine = tt.fields.policyEngine got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...) if (err != nil) != tt.wantErr { diff --git a/authority/status/status.go b/authority/status/status.go deleted file mode 100644 index 49e4c0bb..00000000 --- a/authority/status/status.go +++ /dev/null @@ -1,11 +0,0 @@ -package status - -// Type is the type for status. -type Type string - -var ( - // Active active - Active = Type("active") - // Deleted deleted - Deleted = Type("deleted") -) diff --git a/authority/tls.go b/authority/tls.go index 50dc5642..55e3b49c 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -18,16 +18,19 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/x509util" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" - "go.step.sm/crypto/pemutil" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" + "github.com/smallstep/certificates/policy" ) // GetTLSOptions returns the tls options configured. @@ -91,8 +94,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign // Set backdate with the configured value signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration + var prov provisioner.Interface for _, op := range extraOpts { switch k := op.(type) { + // Capture current provisioner + case provisioner.Interface: + prov = k + // Adds new options to NewCertificate case provisioner.CertificateOptions: certOptions = append(certOptions, k.Options(signOpts)...) @@ -193,6 +201,25 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } + // Check if authority is allowed to sign the certificate + if err := a.isAllowedToSignX509Certificate(leaf); err != nil { + var pe *policy.NamePolicyError + if errors.As(err, &pe) && pe.Reason == policy.NotAllowed { + return nil, errs.ApplyOptions(&errs.Error{ + // NOTE: custom forbidden error, so that denied name is sent to client + // as well as shown in the logs. + Status: http.StatusForbidden, + Err: fmt.Errorf("authority not allowed to sign: %w", err), + Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()), + }, opts...) + } + return nil, errs.InternalServerErr(err, + errs.WithKeyVal("csr", csr), + errs.WithKeyVal("signOptions", signOpts), + errs.WithMessage("error creating certificate"), + ) + } + // Sign certificate lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ @@ -206,7 +233,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) - if err = a.storeCertificate(fullchain); err != nil { + if err = a.storeCertificate(prov, fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error storing certificate in db", opts...) @@ -216,6 +243,18 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign return fullchain, nil } +// isAllowedToSignX509Certificate checks if the Authority is allowed +// to sign the X.509 certificate. +func (a *Authority) isAllowedToSignX509Certificate(cert *x509.Certificate) error { + return a.policyEngine.IsX509CertificateAllowed(cert) +} + +// AreSANsAllowed evaluates the provided sans against the +// authority X.509 policy. +func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error { + return a.policyEngine.AreSANsAllowed(sans) +} + // Renew creates a new Certificate identical to the old certificate, except // with a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { @@ -327,19 +366,33 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 // TODO: at some point we should replace the db.AuthDB interface to implement // `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(*x509.Certificate) error`. -func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { +func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error { type certificateChainStorer interface { + StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error + } + type certificateChainSimpleStorer interface { StoreCertificateChain(...*x509.Certificate) error } + // Store certificate in linkedca - if s, ok := a.adminDB.(certificateChainStorer); ok { + switch s := a.adminDB.(type) { + case certificateChainStorer: + return s.StoreCertificateChain(prov, fullchain...) + case certificateChainSimpleStorer: return s.StoreCertificateChain(fullchain...) } + // Store certificate in local db - if s, ok := a.db.(certificateChainStorer); ok { + switch s := a.db.(type) { + case certificateChainStorer: + return s.StoreCertificateChain(prov, fullchain...) + case certificateChainSimpleStorer: return s.StoreCertificateChain(fullchain...) + case db.CertificateStorer: + return s.StoreCertificate(fullchain[0]) + default: + return nil } - return a.db.StoreCertificate(fullchain[0]) } // storeRenewedCertificate allows to use an extension of the db.AuthDB interface @@ -350,15 +403,21 @@ func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain type renewedCertificateChainStorer interface { StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error } + // Store certificate in linkedca if s, ok := a.adminDB.(renewedCertificateChainStorer); ok { return s.StoreRenewedCertificate(oldCert, fullchain...) } + // Store certificate in local db - if s, ok := a.db.(renewedCertificateChainStorer); ok { + switch s := a.db.(type) { + case renewedCertificateChainStorer: return s.StoreRenewedCertificate(oldCert, fullchain...) + case db.CertificateStorer: + return s.StoreCertificate(fullchain[0]) + default: + return nil } - return a.db.StoreCertificate(fullchain[0]) } // RevokeOptions are the options for the Revoke API. @@ -521,7 +580,7 @@ func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateIn }); ok { return lca.RevokeSSH(crt, rci) } - return a.db.Revoke(rci) + return a.db.RevokeSSH(rci) } // GetCertificateRevocationList will return the currently generated CRL from the DB, or a not implemented @@ -664,7 +723,7 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) { } // Create initial certificate request. - cr, err := x509util.CreateCertificateRequest("Step Online CA", sans, signer) + cr, err := x509util.CreateCertificateRequest(a.config.CommonName, sans, signer) if err != nil { return fatal(err) } diff --git a/authority/tls_test.go b/authority/tls_test.go index e199e0c5..23d2f8fa 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -27,6 +27,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas/softcas" "github.com/smallstep/certificates/db" @@ -511,6 +512,39 @@ ZYtQ9Ot36qc= code: http.StatusForbidden, } }, + "fail with policy": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + aa := testAuthority(t) + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + aa.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + fmt.Println(crt.Subject) + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + return nil + }, + } + options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"test.smallstep.com"}, + }, + }, + } + engine, err := policy.New(options) + assert.FatalError(t, err) + aa.policyEngine = engine + return &signTest{ + auth: aa, + csr: csr, + extraOpts: extraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, + err: errors.New("authority not allowed to sign"), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *signTest { csr := getCSR(t, priv) _a := testAuthority(t) @@ -653,6 +687,38 @@ ZYtQ9Ot36qc= extensionsCount: 7, } }, + "ok with policy": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + aa := testAuthority(t) + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + aa.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + fmt.Println(crt.Subject) + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + return nil + }, + } + options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + CommonNames: []string{"smallstep test"}, + DNSDomains: []string{"*.smallstep.com"}, + }, + }, + } + engine, err := policy.New(options) + assert.FatalError(t, err) + aa.policyEngine = engine + return &signTest{ + auth: aa, + csr: csr, + extraOpts: extraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, + } + }, } for name, genTestCase := range tests { @@ -1235,8 +1301,11 @@ func TestAuthority_Revoke(t *testing.T) { a := testAuthority(t) + tlsRevokeCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) + type test struct { auth *Authority + ctx context.Context opts *RevokeOptions err error code int @@ -1246,6 +1315,7 @@ func TestAuthority_Revoke(t *testing.T) { "fail/token/authorizeRevoke error": func() test { return test{ auth: a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ OTT: "foo", Serial: "sn", @@ -1270,6 +1340,7 @@ func TestAuthority_Revoke(t *testing.T) { return test{ auth: a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, @@ -1309,6 +1380,7 @@ func TestAuthority_Revoke(t *testing.T) { return test{ auth: _a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, @@ -1348,6 +1420,7 @@ func TestAuthority_Revoke(t *testing.T) { return test{ auth: _a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, @@ -1385,6 +1458,7 @@ func TestAuthority_Revoke(t *testing.T) { assert.FatalError(t, err) return test{ auth: _a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, @@ -1401,6 +1475,7 @@ func TestAuthority_Revoke(t *testing.T) { return test{ auth: _a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", @@ -1425,6 +1500,7 @@ func TestAuthority_Revoke(t *testing.T) { return test{ auth: _a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", @@ -1442,6 +1518,7 @@ func TestAuthority_Revoke(t *testing.T) { return test{ auth: _a, + ctx: tlsRevokeCtx, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", @@ -1451,12 +1528,42 @@ func TestAuthority_Revoke(t *testing.T) { }, } }, + "ok/ssh": func() test { + a := testAuthority(t, WithDatabase(&db.MockAuthDB{ + MRevoke: func(rci *db.RevokedCertificateInfo) error { + return errors.New("Revoke was called") + }, + MRevokeSSH: func(rci *db.RevokedCertificateInfo) error { + return nil + }, + })) + + cl := jwt.Claims{ + Subject: "sn", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: validAudience, + ID: "44", + } + raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return test{ + auth: a, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), + opts: &RevokeOptions{ + Serial: "sn", + ReasonCode: reasonCode, + Reason: reason, + OTT: raw, + }, + } + }, } for name, f := range tests { tc := f() t.Run(name, func(t *testing.T) { - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) - if err := tc.auth.Revoke(ctx, tc.opts); err != nil { + if err := tc.auth.Revoke(tc.ctx, tc.opts); err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCodedError interface") diff --git a/ca/adminClient.go b/ca/adminClient.go index 5f3993b1..6532b000 100644 --- a/ca/adminClient.go +++ b/ca/adminClient.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/x509" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -12,18 +13,23 @@ import ( "time" "github.com/pkg/errors" - adminAPI "github.com/smallstep/certificates/authority/admin/api" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/errs" + "google.golang.org/protobuf/encoding/protojson" + "go.step.sm/cli-utils/token" "go.step.sm/cli-utils/token/provision" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" "go.step.sm/linkedca" - "google.golang.org/protobuf/encoding/protojson" + + adminAPI "github.com/smallstep/certificates/authority/admin/api" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" ) -var adminURLPrefix = "admin" +const ( + adminURLPrefix = "admin" + adminIssuer = "step-admin-client/1.0" +) // AdminClient implements an HTTP client for the CA server. type AdminClient struct { @@ -35,7 +41,6 @@ type AdminClient struct { x5cCertFile string x5cCertStrs []string x5cCert *x509.Certificate - x5cIssuer string x5cSubject string } @@ -77,24 +82,30 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error) x5cCertFile: o.x5cCertFile, x5cCertStrs: o.x5cCertStrs, x5cCert: o.x5cCert, - x5cIssuer: o.x5cIssuer, x5cSubject: o.x5cSubject, }, nil } -func (c *AdminClient) generateAdminToken(urlPath string) (string, error) { +func (c *AdminClient) generateAdminToken(aud *url.URL) (string, error) { // A random jwt id will be used to identify duplicated tokens jwtID, err := randutil.Hex(64) // 256 bits if err != nil { return "", err } + // Drop any query string parameter from the token audience + aud = &url.URL{ + Scheme: aud.Scheme, + Host: aud.Host, + Path: aud.Path, + } + now := time.Now() tokOptions := []token.Options{ token.WithJWTID(jwtID), token.WithKid(c.x5cJWK.KeyID), - token.WithIssuer(c.x5cIssuer), - token.WithAudience(urlPath), + token.WithIssuer(adminIssuer), + token.WithAudience(aud.String()), token.WithValidity(now, now.Add(token.DefaultValidity)), token.WithX5CCerts(c.x5cCertStrs), } @@ -205,7 +216,7 @@ func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdmin Path: "/admin/admins", RawQuery: o.rawQuery(), }) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -260,7 +271,7 @@ func (c *AdminClient) CreateAdmin(createAdminRequest *adminAPI.CreateAdminReques return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/admin/admins"}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -292,7 +303,7 @@ retry: func (c *AdminClient) RemoveAdmin(id string) error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } @@ -324,7 +335,7 @@ func (c *AdminClient) UpdateAdmin(id string, uar *adminAPI.UpdateAdminRequest) ( return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -355,23 +366,23 @@ retry: // GetProvisioner performs the GET /admin/provisioners/{name} request to the CA. func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provisioner, error) { var retried bool - o := new(provisionerOptions) - if err := o.apply(opts); err != nil { + o := new(ProvisionerOptions) + if err := o.Apply(opts); err != nil { return nil, err } var u *url.URL switch { - case len(o.id) > 0: + case o.ID != "": u = c.endpoint.ResolveReference(&url.URL{ Path: "/admin/provisioners/id", RawQuery: o.rawQuery(), }) - case len(o.name) > 0: - u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) + case o.Name != "": + u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.Name)}) default: return nil, errors.New("must set either name or id in method options") } - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -402,15 +413,15 @@ retry: // GetProvisionersPaginate performs the GET /admin/provisioners request to the CA. func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*adminAPI.GetProvisionersResponse, error) { var retried bool - o := new(provisionerOptions) - if err := o.apply(opts); err != nil { + o := new(ProvisionerOptions) + if err := o.Apply(opts); err != nil { return nil, err } u := c.endpoint.ResolveReference(&url.URL{ Path: "/admin/provisioners", RawQuery: o.rawQuery(), }) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -464,23 +475,23 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error { retried bool ) - o := new(provisionerOptions) - if err := o.apply(opts); err != nil { + o := new(ProvisionerOptions) + if err := o.Apply(opts); err != nil { return err } switch { - case len(o.id) > 0: + case o.ID != "": u = c.endpoint.ResolveReference(&url.URL{ Path: path.Join(adminURLPrefix, "provisioners/id"), RawQuery: o.rawQuery(), }) - case len(o.name) > 0: - u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) + case o.Name != "": + u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.Name)}) default: return errors.New("must set either name or id in method options") } - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } @@ -512,7 +523,7 @@ func (c *AdminClient) CreateProvisioner(prov *linkedca.Provisioner) (*linkedca.P return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners")}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -548,7 +559,7 @@ func (c *AdminClient) UpdateProvisioner(name string, prov *linkedca.Provisioner) return errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", name)}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } @@ -587,7 +598,7 @@ func (c *AdminClient) GetExternalAccountKeysPaginate(provisionerName, reference Path: p, RawQuery: o.rawQuery(), }) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -623,7 +634,7 @@ func (c *AdminClient) CreateExternalAccountKey(provisionerName string, eakReques return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab/", provisionerName)}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } @@ -655,7 +666,7 @@ retry: func (c *AdminClient) RemoveExternalAccountKey(provisionerName, keyID string) error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab", provisionerName, "/", keyID)}) - tok, err := c.generateAdminToken(u.Path) + tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } @@ -679,6 +690,418 @@ retry: return nil } +func (c *AdminClient) GetAuthorityPolicy() (*linkedca.Policy, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client GET %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) CreateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client POST %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) UpdateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client PUT %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) RemoveAuthorityPolicy() error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) + if err != nil { + return fmt.Errorf("creating DELETE %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("client DELETE %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +func (c *AdminClient) GetProvisionerPolicy(provisionerName string) (*linkedca.Policy, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client GET %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) CreateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client POST %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) UpdateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client PUT %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) RemoveProvisionerPolicy(provisionerName string) error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) + if err != nil { + return fmt.Errorf("creating DELETE %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("client DELETE %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +func (c *AdminClient) GetACMEPolicy(provisionerName, reference, keyID string) (*linkedca.Policy, error) { + var retried bool + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client GET %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) CreateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client POST %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) UpdateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client PUT %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) RemoveACMEPolicy(provisionerName, reference, keyID string) error { + var retried bool + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) + if err != nil { + return fmt.Errorf("creating DELETE %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("client DELETE %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + func readAdminError(r io.ReadCloser) error { // TODO: not all errors can be read (i.e. 404); seems to be a bigger issue defer r.Close() diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 0e0f0fe3..430f2e31 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -48,17 +48,18 @@ func Bootstrap(token string) (*Client, error) { // certificate after 2/3rd of the certificate's lifetime has expired. // // Usage: -// // Default example with certificate rotation. -// client, err := ca.BootstrapClient(ctx.Background(), token) // -// // Example canceling automatic certificate rotation. -// ctx, cancel := context.WithCancel(context.Background()) -// defer cancel() -// client, err := ca.BootstrapClient(ctx, token) -// if err != nil { -// return err -// } -// resp, err := client.Get("https://internal.smallstep.com") +// // Default example with certificate rotation. +// client, err := ca.BootstrapClient(ctx.Background(), token) +// +// // Example canceling automatic certificate rotation. +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// client, err := ca.BootstrapClient(ctx, token) +// if err != nil { +// return err +// } +// resp, err := client.Get("https://internal.smallstep.com") func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { b, err := createBootstrap(token) if err != nil { @@ -96,23 +97,24 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* // ca.AddClientCA(*x509.Certificate). // // Usage: -// // Default example with certificate rotation. -// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{ -// Addr: ":443", -// Handler: handler, -// }) // -// // Example canceling automatic certificate rotation. -// ctx, cancel := context.WithCancel(context.Background()) -// defer cancel() -// srv, err := ca.BootstrapServer(ctx, token, &http.Server{ -// Addr: ":443", -// Handler: handler, -// }) -// if err != nil { -// return err -// } -// srv.ListenAndServeTLS("", "") +// // Default example with certificate rotation. +// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{ +// Addr: ":443", +// Handler: handler, +// }) +// +// // Example canceling automatic certificate rotation. +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// srv, err := ca.BootstrapServer(ctx, token, &http.Server{ +// Addr: ":443", +// Handler: handler, +// }) +// if err != nil { +// return err +// } +// srv.ListenAndServeTLS("", "") func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) { if base.TLSConfig != nil { return nil, errors.New("server TLSConfig is already set") @@ -152,19 +154,20 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio // ca.AddClientCA(*x509.Certificate). // // Usage: -// inner, err := net.Listen("tcp", ":443") -// if err != nil { -// return nil -// } -// ctx, cancel := context.WithCancel(context.Background()) -// defer cancel() -// lis, err := ca.BootstrapListener(ctx, token, inner) -// if err != nil { -// return err -// } -// srv := grpc.NewServer() -// ... // register services -// srv.Serve(lis) +// +// inner, err := net.Listen("tcp", ":443") +// if err != nil { +// return nil +// } +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// lis, err := ca.BootstrapListener(ctx, token, inner) +// if err != nil { +// return err +// } +// srv := grpc.NewServer() +// ... // register services +// srv.Serve(lis) func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { b, err := createBootstrap(token) if err != nil { diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 2332b4d4..ccbdbc22 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/httptest" + "os" "reflect" "strings" "sync" @@ -53,7 +54,11 @@ func startCABootstrapServer() *httptest.Server { if err != nil { panic(err) } + baseContext := buildContext(ca.auth, nil, nil, nil) srv.Config.Handler = ca.srv.Handler + srv.Config.BaseContext = func(net.Listener) context.Context { + return baseContext + } srv.TLS = ca.srv.TLSConfig srv.StartTLS() // Force the use of GetCertificate on IPs @@ -92,6 +97,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han for _, s := range nonAuthenticatedPaths { if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) { next.ServeHTTP(w, r) + return } } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 @@ -369,6 +375,9 @@ func TestBootstrapClient(t *testing.T) { } func TestBootstrapClientServerRotation(t *testing.T) { + if os.Getenv("CI") == "true" { + t.Skipf("skip until we fix https://github.com/smallstep/certificates/issues/873") + } reset := setMinCertDuration(1 * time.Second) defer reset() diff --git a/ca/ca.go b/ca/ca.go index 0d4f1578..7c00bb6b 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -1,10 +1,12 @@ package ca import ( + "context" "crypto/tls" "crypto/x509" "fmt" "log" + "net" "net/http" "net/url" "reflect" @@ -18,6 +20,7 @@ import ( acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" adminAPI "github.com/smallstep/certificates/authority/admin/api" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/db" @@ -170,10 +173,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler := http.Handler(insecureMux) // Add regular CA api endpoints in / and /1.0 - routerHandler := api.New(auth) - routerHandler.Route(mux) + api.Route(mux) mux.Route("/1.0", func(r chi.Router) { - routerHandler.Route(r) + api.Route(r) }) //Add ACME api endpoints in /acme and /1.0/acme @@ -187,48 +189,41 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { dns = fmt.Sprintf("%s:%s", dns, port) } - // ACME Router - prefix := "acme" + // ACME Router is only available if we have a database. var acmeDB acme.DB - if cfg.DB == nil { - acmeDB = nil - } else { + var acmeLinker acme.Linker + if cfg.DB != nil { acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) if err != nil { return nil, errors.Wrap(err, "error configuring ACME DB interface") } + acmeLinker = acme.NewLinker(dns, "acme") + mux.Route("/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) + // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 + // of the ACME spec. + mux.Route("/2.0/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) } - acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ - Backdate: *cfg.AuthorityConfig.Backdate, - DB: acmeDB, - DNS: dns, - Prefix: prefix, - CA: auth, - }) - mux.Route("/"+prefix, func(r chi.Router) { - acmeHandler.Route(r) - }) - // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 - // of the ACME spec. - mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeHandler.Route(r) - }) // Admin API Router if cfg.AuthorityConfig.EnableAdmin { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() - adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder) + policyAdminResponder := adminAPI.NewPolicyAdminResponder() mux.Route("/admin", func(r chi.Router) { - adminHandler.Route(r) + adminAPI.Route(r, acmeAdminResponder, policyAdminResponder) }) } } + var scepAuthority *scep.Authority if ca.shouldServeSCEPEndpoints() { scepPrefix := "scep" - scepAuthority, err := scep.New(auth, scep.AuthorityOptions{ + scepAuthority, err = scep.New(auth, scep.AuthorityOptions{ Service: auth.GetSCEPService(), DNS: dns, Prefix: scepPrefix, @@ -236,13 +231,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { if err != nil { return nil, errors.Wrap(err, "error creating SCEP authority") } - scepRouterHandler := scepAPI.New(scepAuthority) // According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10), // SCEP operations are performed using HTTP, so that's why the API is mounted // to the insecure mux. insecureMux.Route("/"+scepPrefix, func(r chi.Router) { - scepRouterHandler.Route(r) + scepAPI.Route(r) }) // The RFC also mentions usage of HTTPS, but seems to advise @@ -252,7 +246,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // as well as HTTPS can be used to request certificates // using SCEP. mux.Route("/"+scepPrefix, func(r chi.Router) { - scepRouterHandler.Route(r) + scepAPI.Route(r) }) } @@ -279,7 +273,13 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = logger.Middleware(insecureHandler) } + // Create context with all the necessary values. + baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) + ca.srv = server.New(cfg.Address, handler, tlsConfig) + ca.srv.BaseContext = func(net.Listener) context.Context { + return baseContext + } // only start the insecure server if the insecure address is configured // and, currently, also only when it should serve SCEP endpoints. @@ -289,11 +289,32 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // will probably introduce more complexity in terms of graceful // reload. ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil) + ca.insecureSrv.BaseContext = func(net.Listener) context.Context { + return baseContext + } } return ca, nil } +// buildContext builds the server base context. +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { + ctx := authority.NewContext(context.Background(), a) + if authDB := a.GetDatabase(); authDB != nil { + ctx = db.NewContext(ctx, authDB) + } + if adminDB := a.GetAdminDatabase(); adminDB != nil { + ctx = admin.NewContext(ctx, adminDB) + } + if scepAuthority != nil { + ctx = scep.NewContext(ctx, scepAuthority) + } + if acmeDB != nil { + ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) + } + return ctx +} + // Run starts the CA calling to the server ListenAndServe method. func (ca *CA) Run() error { var wg sync.WaitGroup @@ -321,7 +342,7 @@ func (ca *CA) Run() error { log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt)) } if authorityInfo.SSHCAHostPublicKey != nil { - log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey) + log.Printf("SSH Host CA Key: %s\n", authorityInfo.SSHCAHostPublicKey) } if authorityInfo.SSHCAUserPublicKey != nil { log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey) @@ -502,7 +523,7 @@ func (ca *CA) shouldServeSCEPEndpoints() bool { return ca.auth.GetSCEPService() != nil } -//nolint // ignore linters to allow keeping this function around for debugging +// nolint // ignore linters to allow keeping this function around for debugging func dumpRoutes(mux chi.Routes) { // helpful routine for logging all routes // walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { diff --git a/ca/ca_test.go b/ca/ca_test.go index e4c35a90..29eac575 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -2,6 +2,7 @@ package ca import ( "bytes" + "context" "crypto" "crypto/rand" "crypto/sha1" @@ -281,7 +282,8 @@ ZEp7knvU2psWRw== assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -360,7 +362,8 @@ func TestCAProvisioners(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -426,7 +429,8 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -487,7 +491,8 @@ func TestCARoot(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -534,7 +539,8 @@ func TestCAHealth(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -628,7 +634,8 @@ func TestCARenew(t *testing.T) { rq.TLS = tc.tlsConnState rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} diff --git a/ca/client.go b/ca/client.go index 3a36fcd6..44961357 100644 --- a/ca/client.go +++ b/ca/client.go @@ -10,7 +10,6 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" - "encoding/asn1" "encoding/hex" "encoding/json" "encoding/pem" @@ -116,7 +115,6 @@ type clientOptions struct { x5cCertFile string x5cCertStrs []string x5cCert *x509.Certificate - x5cIssuer string x5cSubject string } @@ -294,18 +292,6 @@ func WithCertificate(cert tls.Certificate) ClientOption { } } -var ( - stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} - stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) -) - -type stepProvisionerASN1 struct { - Type int - Name []byte - CredentialID []byte - KeyValuePairs []string `asn1:"optional,omitempty"` -} - // WithAdminX5C will set the given file as the X5C certificate for use // by the client. func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile string) ClientOption { @@ -332,19 +318,13 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin } o.x5cCert = certs[0] - o.x5cSubject = o.x5cCert.Subject.CommonName - - for _, e := range o.x5cCert.Extensions { - if e.Id.Equal(stepOIDProvisioner) { - var prov stepProvisionerASN1 - if _, err := asn1.Unmarshal(e.Value, &prov); err != nil { - return errors.Wrap(err, "error unmarshaling provisioner OID from certificate") - } - o.x5cIssuer = string(prov.Name) - } - } - if o.x5cIssuer == "" { - return errors.New("provisioner extension not found in certificate") + switch leaf := certs[0]; { + case leaf.Subject.CommonName != "": + o.x5cSubject = leaf.Subject.CommonName + case len(leaf.DNSNames) > 0: + o.x5cSubject = leaf.DNSNames[0] + case len(leaf.EmailAddresses) > 0: + o.x5cSubject = leaf.EmailAddresses[0] } return nil @@ -445,16 +425,18 @@ func parseEndpoint(endpoint string) (*url.URL, error) { } // ProvisionerOption is the type of options passed to the Provisioner method. -type ProvisionerOption func(o *provisionerOptions) error +type ProvisionerOption func(o *ProvisionerOptions) error -type provisionerOptions struct { - cursor string - limit int - id string - name string +// ProvisionerOptions stores options for the provisioner CRUD API. +type ProvisionerOptions struct { + Cursor string + Limit int + ID string + Name string } -func (o *provisionerOptions) apply(opts []ProvisionerOption) (err error) { +// Apply caches provisioner options on a struct for later use. +func (o *ProvisionerOptions) Apply(opts []ProvisionerOption) (err error) { for _, fn := range opts { if err = fn(o); err != nil { return @@ -463,51 +445,51 @@ func (o *provisionerOptions) apply(opts []ProvisionerOption) (err error) { return } -func (o *provisionerOptions) rawQuery() string { +func (o *ProvisionerOptions) rawQuery() string { v := url.Values{} - if len(o.cursor) > 0 { - v.Set("cursor", o.cursor) + if o.Cursor != "" { + v.Set("cursor", o.Cursor) } - if o.limit > 0 { - v.Set("limit", strconv.Itoa(o.limit)) + if o.Limit > 0 { + v.Set("limit", strconv.Itoa(o.Limit)) } - if len(o.id) > 0 { - v.Set("id", o.id) + if o.ID != "" { + v.Set("id", o.ID) } - if len(o.name) > 0 { - v.Set("name", o.name) + if o.Name != "" { + v.Set("name", o.Name) } return v.Encode() } // WithProvisionerCursor will request the provisioners starting with the given cursor. func WithProvisionerCursor(cursor string) ProvisionerOption { - return func(o *provisionerOptions) error { - o.cursor = cursor + return func(o *ProvisionerOptions) error { + o.Cursor = cursor return nil } } // WithProvisionerLimit will request the given number of provisioners. func WithProvisionerLimit(limit int) ProvisionerOption { - return func(o *provisionerOptions) error { - o.limit = limit + return func(o *ProvisionerOptions) error { + o.Limit = limit return nil } } // WithProvisionerID will request the given provisioner. func WithProvisionerID(id string) ProvisionerOption { - return func(o *provisionerOptions) error { - o.id = id + return func(o *ProvisionerOptions) error { + o.ID = id return nil } } // WithProvisionerName will request the given provisioner. func WithProvisionerName(name string) ProvisionerOption { - return func(o *provisionerOptions) error { - o.name = name + return func(o *ProvisionerOptions) error { + o.Name = name return nil } } @@ -830,8 +812,8 @@ retry: // paginate the provisioners. func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { var retried bool - o := new(provisionerOptions) - if err := o.apply(opts); err != nil { + o := new(ProvisionerOptions) + if err := o.Apply(opts); err != nil { return nil, err } u := c.endpoint.ResolveReference(&url.URL{ diff --git a/ca/tls.go b/ca/tls.go index 7954cbdf..57440bad 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "os" + "runtime" "time" "github.com/pkg/errors" @@ -288,10 +289,20 @@ func getDefaultDialer() *net.Dialer { // transport for HTTP/2. func getDefaultTransport(tlsConfig *tls.Config) *http.Transport { var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) - if mTLSDialContext == nil { + switch { + case runtime.GOOS == "js" && runtime.GOARCH == "wasm": + // when running in js/wasm and using the default dialer context all requests + // performed by the CA client resulted in a "protocol not supported" error. + // By setting the dial context to nil requests will be handled by the browser + // fetch API instead. Currently this will always set the dial context to nil, + // but we could implement some additional logic similar to what's found in + // https://github.com/golang/go/pull/46923/files to support a different dial + // context if it is available, required and expected to work. + dialContext = nil + case mTLSDialContext == nil: d := getDefaultDialer() dialContext = d.DialContext - } else { + default: dialContext = mTLSDialContext() } return &http.Transport{ diff --git a/ca/tls_test.go b/ca/tls_test.go index 93dbe9b3..946a6cb5 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -10,6 +10,7 @@ import ( "encoding/hex" "io" "log" + "net" "net/http" "net/http/httptest" "reflect" @@ -77,7 +78,12 @@ func startCATestServer() *httptest.Server { panic(err) } // Use a httptest.Server instead - return startTestServer(ca.srv.TLSConfig, ca.srv.Handler) + srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) + baseContext := buildContext(ca.auth, nil, nil, nil) + srv.Config.BaseContext = func(net.Listener) context.Context { + return baseContext + } + return srv } func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index 3fc34208..f69f933b 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -3,6 +3,7 @@ package apiv1 import ( "crypto" "crypto/x509" + "encoding/json" "github.com/pkg/errors" "github.com/smallstep/certificates/kms" @@ -15,8 +16,9 @@ type Options struct { Type string `json:"type"` // CertificateAuthority reference: - // In StepCAS the value is the CA url, e.g. "https://ca.smallstep.com:9000". + // In StepCAS the value is the CA url, e.g., "https://ca.smallstep.com:9000". // In CloudCAS the format is "projects/*/locations/*/certificateAuthorities/*". + // In VaultCAS the value is the url, e.g., "https://vault.smallstep.com". CertificateAuthority string `json:"certificateAuthority,omitempty"` // CertificateAuthorityFingerprint is the root fingerprint used to @@ -69,6 +71,9 @@ type Options struct { CaPool string `json:"-"` CaPoolTier string `json:"-"` GCSBucket string `json:"-"` + + // Generic structure to configure any CAS + Config json.RawMessage `json:"config,omitempty"` } // CertificateIssuer contains the properties used to use the StepCAS certificate diff --git a/cas/apiv1/services.go b/cas/apiv1/services.go index ce7d0364..43f95d81 100644 --- a/cas/apiv1/services.go +++ b/cas/apiv1/services.go @@ -51,6 +51,8 @@ const ( CloudCAS = "cloudcas" // StepCAS is a CertificateAuthorityService using another step-ca instance. StepCAS = "stepcas" + // VaultCAS is a CertificateAuthorityService using Hasicorp Vault PKI. + VaultCAS = "vaultcas" ) // String returns a string from the type. It will always return the lower case diff --git a/cas/cloudcas/certificate_test.go b/cas/cloudcas/certificate_test.go index 8bf67fb6..0cabdf5b 100644 --- a/cas/cloudcas/certificate_test.go +++ b/cas/cloudcas/certificate_test.go @@ -112,6 +112,7 @@ func Test_createPublicKey(t *testing.T) { t.Fatal(err) } ecCert := mustParseCertificate(t, testLeafCertificate) + ecCertPublicKey := ecCert.PublicKey.(*ecdsa.PublicKey) rsaCert := mustParseCertificate(t, testRSACertificate) type args struct { key crypto.PublicKey @@ -132,9 +133,14 @@ func Test_createPublicKey(t *testing.T) { }, false}, {"fail ed25519", args{edpub}, nil, true}, {"fail ec marshal", args{&ecdsa.PublicKey{ - Curve: &elliptic.CurveParams{Name: "FOO", BitSize: 256}, - X: ecCert.PublicKey.(*ecdsa.PublicKey).X, - Y: ecCert.PublicKey.(*ecdsa.PublicKey).Y, + Curve: &elliptic.CurveParams{ + Name: "FOO", + BitSize: 256, + P: ecCertPublicKey.Params().P, + B: ecCertPublicKey.Params().B, + }, + X: ecCertPublicKey.X, + Y: ecCertPublicKey.Y, }}, nil, true}, } for _, tt := range tests { diff --git a/cas/cloudcas/cloudcas.go b/cas/cloudcas/cloudcas.go index e3e956a9..34ff8506 100644 --- a/cas/cloudcas/cloudcas.go +++ b/cas/cloudcas/cloudcas.go @@ -32,7 +32,9 @@ func init() { var now = time.Now // The actual regular expression that matches a certificate authority is: -// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$ +// +// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$ +// // But we will allow a more flexible one to fail if this changes. var caRegexp = regexp.MustCompile("^projects/[^/]+/locations/[^/]+/caPools/[^/]+/certificateAuthorities/[^/]+$") diff --git a/cas/cloudcas/mock_client_test.go b/cas/cloudcas/mock_client_test.go index de5c2acb..90d1a2f9 100644 --- a/cas/cloudcas/mock_client_test.go +++ b/cas/cloudcas/mock_client_test.go @@ -5,12 +5,13 @@ package cloudcas import ( - privateca "cloud.google.com/go/security/privateca/apiv1" context "context" + reflect "reflect" + + privateca "cloud.google.com/go/security/privateca/apiv1" gomock "github.com/golang/mock/gomock" gax "github.com/googleapis/gax-go/v2" privateca0 "google.golang.org/genproto/googleapis/cloud/security/privateca/v1" - reflect "reflect" ) // MockCertificateAuthorityClient is a mock of CertificateAuthorityClient interface diff --git a/cas/cloudcas/mock_operation_server_test.go b/cas/cloudcas/mock_operation_server_test.go index ee2743d4..43dfa713 100644 --- a/cas/cloudcas/mock_operation_server_test.go +++ b/cas/cloudcas/mock_operation_server_test.go @@ -6,10 +6,11 @@ package cloudcas import ( context "context" + reflect "reflect" + gomock "github.com/golang/mock/gomock" longrunning "google.golang.org/genproto/googleapis/longrunning" emptypb "google.golang.org/protobuf/types/known/emptypb" - reflect "reflect" ) // MockOperationsServer is a mock of OperationsServer interface diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 0b2270bb..b03faad7 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -3,6 +3,7 @@ package softcas import ( "context" "crypto" + "crypto/rsa" "crypto/rand" "crypto/x509" "time" @@ -260,6 +261,8 @@ func createCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, if template.SignatureAlgorithm == 0 { if sa, ok := signer.(apiv1.SignatureAlgorithmGetter); ok { template.SignatureAlgorithm = sa.SignatureAlgorithm() + } else if _, ok := parent.PublicKey.(*rsa.PublicKey); ok { + template.SignatureAlgorithm = parent.SignatureAlgorithm } } return x509util.CreateCertificate(template, parent, pub, signer) diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index b4f5b440..0651ab4d 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto" "crypto/rand" + "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "fmt" @@ -350,6 +351,67 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { } } +func TestSoftCAS_CreateCertificate_pss(t *testing.T) { + signer, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + template := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, + PublicKey: signer.Public(), + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 0, + SerialNumber: big.NewInt(1234), + SignatureAlgorithm: x509.SHA256WithRSAPSS, + NotBefore: now, + NotAfter: now.Add(24 * time.Hour), + } + + iss, err := x509util.CreateCertificate(template, template, signer.Public(), signer) + if err != nil { + t.Fatal(err) + } + if iss.SignatureAlgorithm != x509.SHA256WithRSAPSS { + t.Errorf("Certificate.SignatureAlgorithm = %v, want %v", iss.SignatureAlgorithm, x509.SHA256WithRSAPSS) + } + + c := &SoftCAS{ + CertificateChain: []*x509.Certificate{iss}, + Signer: signer, + } + cert, err := c.CreateCertificate(&apiv1.CreateCertificateRequest{ + Template: &x509.Certificate{ + Subject: pkix.Name{CommonName: "test.smallstep.com"}, + DNSNames: []string{"test.smallstep.com"}, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + PublicKey: testSigner.Public(), + SerialNumber: big.NewInt(1234), + }, + Lifetime: time.Hour, Backdate: time.Minute, + }) + if err != nil { + t.Fatalf("SoftCAS.CreateCertificate() error = %v", err) + } + if cert.Certificate.SignatureAlgorithm != x509.SHA256WithRSAPSS { + t.Errorf("Certificate.SignatureAlgorithm = %v, want %v", iss.SignatureAlgorithm, x509.SHA256WithRSAPSS) + } + + pool := x509.NewCertPool() + pool.AddCert(iss) + if _, err = cert.Certificate.Verify(x509.VerifyOptions{ + CurrentTime: time.Now(), + Roots: pool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + }); err != nil { + t.Errorf("Certificate.Verify() error = %v", err) + } +} + func TestSoftCAS_RenewCertificate(t *testing.T) { mockNow(t) diff --git a/cas/vaultcas/auth/approle/approle.go b/cas/vaultcas/auth/approle/approle.go new file mode 100644 index 00000000..118afb10 --- /dev/null +++ b/cas/vaultcas/auth/approle/approle.go @@ -0,0 +1,67 @@ +package approle + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/hashicorp/vault/api/auth/approle" +) + +// AuthOptions defines the configuration options added using the +// VaultOptions.AuthOptions field when AuthType is approle +type AuthOptions struct { + RoleID string `json:"roleID,omitempty"` + SecretID string `json:"secretID,omitempty"` + SecretIDFile string `json:"secretIDFile,omitempty"` + SecretIDEnv string `json:"secretIDEnv,omitempty"` + IsWrappingToken bool `json:"isWrappingToken,omitempty"` +} + +func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.AppRoleAuth, error) { + var opts *AuthOptions + + err := json.Unmarshal(options, &opts) + if err != nil { + return nil, fmt.Errorf("error decoding AppRole auth options: %w", err) + } + + var approleAuth *approle.AppRoleAuth + + var loginOptions []approle.LoginOption + if mountPath != "" { + loginOptions = append(loginOptions, approle.WithMountPath(mountPath)) + } + if opts.IsWrappingToken { + loginOptions = append(loginOptions, approle.WithWrappingToken()) + } + + if opts.RoleID == "" { + return nil, errors.New("you must set roleID") + } + + var sid approle.SecretID + switch { + case opts.SecretID != "" && opts.SecretIDFile == "" && opts.SecretIDEnv == "": + sid = approle.SecretID{ + FromString: opts.SecretID, + } + case opts.SecretIDFile != "" && opts.SecretID == "" && opts.SecretIDEnv == "": + sid = approle.SecretID{ + FromFile: opts.SecretIDFile, + } + case opts.SecretIDEnv != "" && opts.SecretIDFile == "" && opts.SecretID == "": + sid = approle.SecretID{ + FromEnv: opts.SecretIDEnv, + } + default: + return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv") + } + + approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + return approleAuth, nil +} diff --git a/cas/vaultcas/auth/approle/approle_test.go b/cas/vaultcas/auth/approle/approle_test.go new file mode 100644 index 00000000..28b7b7f7 --- /dev/null +++ b/cas/vaultcas/auth/approle/approle_test.go @@ -0,0 +1,195 @@ +package approle + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + vault "github.com/hashicorp/vault/api" +) + +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.0000" + } + }`) + case r.RequestURI == "/v1/auth/custom-approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.9999" + } + }`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } + + config := vault.DefaultConfig() + config.Address = srv.URL + + client, err := vault.NewClient(config) + if err != nil { + srv.Close() + t.Fatal(err) + } + + return u, client +} + +func TestApprole_LoginMountPaths(t *testing.T) { + caURL, _ := testCAHelper(t) + + config := vault.DefaultConfig() + config.Address = caURL.String() + client, _ := vault.NewClient(config) + + tests := []struct { + name string + mountPath string + token string + }{ + { + name: "ok default mount path", + mountPath: "", + token: "hvs.0000", + }, + { + name: "ok explicit mount path", + mountPath: "approle", + token: "hvs.0000", + }, + { + name: "ok custom mount path", + mountPath: "custom-approle", + token: "hvs.9999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`)) + if err != nil { + t.Errorf("NewApproleAuthMethod() error = %v", err) + return + } + + secret, err := client.Auth().Login(context.Background(), method) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + token, _ := secret.TokenID() + if token != tt.token { + t.Errorf("Token error got %v, expected %v", token, tt.token) + return + } + }) + } +} + +func TestApprole_NewApproleAuthMethod(t *testing.T) { + tests := []struct { + name string + mountPath string + raw string + wantErr bool + }{ + { + "ok secret-id string", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000"}`, + false, + }, + { + "ok secret-id string and wrapped", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, + false, + }, + { + "ok secret-id string and wrapped with custom mountPath", + "approle2", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, + false, + }, + { + "ok secret-id file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, + false, + }, + { + "ok secret-id env", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + false, + }, + { + "fail mandatory role-id", + "", + `{}`, + true, + }, + { + "fail mandatory secret-id any", + "", + `{"RoleID": "0000-0000-0000-0000"}`, + true, + }, + { + "fail multiple secret-id types id and env", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + { + "fail multiple secret-id types id and file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, + true, + }, + { + "fail multiple secret-id types env and file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + { + "fail multiple secret-id types all", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("Approle.NewApproleAuthMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/cas/vaultcas/auth/kubernetes/kubernetes.go b/cas/vaultcas/auth/kubernetes/kubernetes.go new file mode 100644 index 00000000..267bcdca --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/kubernetes.go @@ -0,0 +1,49 @@ +package kubernetes + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/hashicorp/vault/api/auth/kubernetes" +) + +// AuthOptions defines the configuration options added using the +// VaultOptions.AuthOptions field when AuthType is kubernetes +type AuthOptions struct { + Role string `json:"role,omitempty"` + TokenPath string `json:"tokenPath,omitempty"` +} + +func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubernetes.KubernetesAuth, error) { + var opts *AuthOptions + + err := json.Unmarshal(options, &opts) + if err != nil { + return nil, fmt.Errorf("error decoding Kubernetes auth options: %w", err) + } + + var kubernetesAuth *kubernetes.KubernetesAuth + + var loginOptions []kubernetes.LoginOption + if mountPath != "" { + loginOptions = append(loginOptions, kubernetes.WithMountPath(mountPath)) + } + if opts.TokenPath != "" { + loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath)) + } + + if opts.Role == "" { + return nil, errors.New("you must set role") + } + + kubernetesAuth, err = kubernetes.NewKubernetesAuth( + opts.Role, + loginOptions..., + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + return kubernetesAuth, nil +} diff --git a/cas/vaultcas/auth/kubernetes/kubernetes_test.go b/cas/vaultcas/auth/kubernetes/kubernetes_test.go new file mode 100644 index 00000000..55be904d --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/kubernetes_test.go @@ -0,0 +1,149 @@ +package kubernetes + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "path" + "path/filepath" + "runtime" + "testing" + + vault "github.com/hashicorp/vault/api" +) + +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/kubernetes/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.0000" + } + }`) + case r.RequestURI == "/v1/auth/custom-kubernetes/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.9999" + } + }`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } + + config := vault.DefaultConfig() + config.Address = srv.URL + + client, err := vault.NewClient(config) + if err != nil { + srv.Close() + t.Fatal(err) + } + + return u, client +} + +func TestApprole_LoginMountPaths(t *testing.T) { + caURL, _ := testCAHelper(t) + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + + config := vault.DefaultConfig() + config.Address = caURL.String() + client, _ := vault.NewClient(config) + + tests := []struct { + name string + mountPath string + token string + }{ + { + name: "ok default mount path", + mountPath: "", + token: "hvs.0000", + }, + { + name: "ok explicit mount path", + mountPath: "kubernetes", + token: "hvs.0000", + }, + { + name: "ok custom mount path", + mountPath: "custom-kubernetes", + token: "hvs.9999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(`{"role": "SomeRoleName", "tokenPath": "`+tokenPath+`"}`)) + if err != nil { + t.Errorf("NewApproleAuthMethod() error = %v", err) + return + } + + secret, err := client.Auth().Login(context.Background(), method) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + token, _ := secret.TokenID() + if token != tt.token { + t.Errorf("Token error got %v, expected %v", token, tt.token) + return + } + }) + } +} + +func TestApprole_NewApproleAuthMethod(t *testing.T) { + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + + tests := []struct { + name string + mountPath string + raw string + wantErr bool + }{ + { + "ok secret-id string", + "", + `{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}`, + false, + }, + { + "fail mandatory role", + "", + `{}`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("Kubernetes.NewKubernetesAuthMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/cas/vaultcas/auth/kubernetes/token b/cas/vaultcas/auth/kubernetes/token new file mode 100644 index 00000000..6745be67 --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/token @@ -0,0 +1 @@ +token \ No newline at end of file diff --git a/cas/vaultcas/vaultcas.go b/cas/vaultcas/vaultcas.go new file mode 100644 index 00000000..a5658620 --- /dev/null +++ b/cas/vaultcas/vaultcas.go @@ -0,0 +1,329 @@ +package vaultcas + +import ( + "bytes" + "context" + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "math/big" + "strings" + "time" + + "github.com/smallstep/certificates/cas/apiv1" + "github.com/smallstep/certificates/cas/vaultcas/auth/approle" + "github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes" + + vault "github.com/hashicorp/vault/api" +) + +func init() { + apiv1.Register(apiv1.VaultCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) { + return New(ctx, opts) + }) +} + +// VaultOptions defines the configuration options added using the +// apiv1.Options.Config field. +type VaultOptions struct { + PKIMountPath string `json:"pkiMountPath,omitempty"` + PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` + PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` + PKIRoleEC string `json:"pkiRoleEC,omitempty"` + PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` + AuthType string `json:"authType,omitempty"` + AuthMountPath string `json:"authMountPath,omitempty"` + AuthOptions json.RawMessage `json:"authOptions,omitempty"` +} + +// VaultCAS implements a Certificate Authority Service using Hashicorp Vault. +type VaultCAS struct { + client *vault.Client + config VaultOptions + fingerprint string +} + +type certBundle struct { + leaf *x509.Certificate + intermediates []*x509.Certificate + root *x509.Certificate +} + +// New creates a new CertificateAuthorityService implementation +// using Hashicorp Vault +func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { + if opts.CertificateAuthority == "" { + return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty") + } + + if opts.CertificateAuthorityFingerprint == "" { + return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty") + } + + vc, err := loadOptions(opts.Config) + if err != nil { + return nil, err + } + + config := vault.DefaultConfig() + config.Address = opts.CertificateAuthority + + client, err := vault.NewClient(config) + if err != nil { + return nil, fmt.Errorf("unable to initialize vault client: %w", err) + } + + var method vault.AuthMethod + switch vc.AuthType { + case "kubernetes": + method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions) + case "approle": + method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions) + default: + return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType) + } + if err != nil { + return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err) + } + + authInfo, err := client.Auth().Login(ctx, method) + if err != nil { + return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err) + } + if authInfo == nil { + return nil, errors.New("no auth info was returned after login") + } + + return &VaultCAS{ + client: client, + config: *vc, + fingerprint: opts.CertificateAuthorityFingerprint, + }, nil +} + +// CreateCertificate signs a new certificate using Hashicorp Vault. +func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { + switch { + case req.CSR == nil: + return nil, errors.New("createCertificate `csr` cannot be nil") + case req.Lifetime == 0: + return nil, errors.New("createCertificate `lifetime` cannot be 0") + } + + cert, chain, err := v.createCertificate(req.CSR, req.Lifetime) + if err != nil { + return nil, err + } + + return &apiv1.CreateCertificateResponse{ + Certificate: cert, + CertificateChain: chain, + }, nil +} + +// GetCertificateAuthority returns the root certificate of the certificate +// authority using the configured fingerprint. +func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { + secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain") + if err != nil { + return nil, fmt.Errorf("error reading ca chain: %w", err) + } + if secret == nil { + return nil, errors.New("error reading ca chain: response is empty") + } + + chain, ok := secret.Data["certificate"].(string) + if !ok { + return nil, errors.New("error unmarshaling vault response: certificate not found") + } + + cert, err := getCertificateBundle(chain) + if err != nil { + return nil, err + } + if cert.root == nil { + return nil, errors.New("error unmarshaling vault response: root certificate not found") + } + + sum := sha256.Sum256(cert.root.Raw) + if !strings.EqualFold(v.fingerprint, strings.ToLower(hex.EncodeToString(sum[:]))) { + return nil, errors.New("error verifying vault root: fingerprint does not match") + } + + return &apiv1.GetCertificateAuthorityResponse{ + RootCertificate: cert.root, + }, nil +} + +// RenewCertificate will always return a non-implemented error as renewals +// are not supported yet. +func (v *VaultCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { + return nil, apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"} +} + +// RevokeCertificate revokes a certificate by serial number. +func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { + if req.SerialNumber == "" && req.Certificate == nil { + return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required") + } + + var sn *big.Int + if req.SerialNumber != "" { + var ok bool + if sn, ok = new(big.Int).SetString(req.SerialNumber, 10); !ok { + return nil, fmt.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber) + } + } else { + sn = req.Certificate.SerialNumber + } + + vaultReq := map[string]interface{}{ + "serial_number": formatSerialNumber(sn), + } + _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", vaultReq) + if err != nil { + return nil, fmt.Errorf("error revoking certificate: %w", err) + } + + return &apiv1.RevokeCertificateResponse{ + Certificate: req.Certificate, + CertificateChain: nil, + }, nil +} + +func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) { + var vaultPKIRole string + + switch { + case cr.PublicKeyAlgorithm == x509.RSA: + vaultPKIRole = v.config.PKIRoleRSA + case cr.PublicKeyAlgorithm == x509.ECDSA: + vaultPKIRole = v.config.PKIRoleEC + case cr.PublicKeyAlgorithm == x509.Ed25519: + vaultPKIRole = v.config.PKIRoleEd25519 + default: + return nil, nil, fmt.Errorf("unsupported public key algorithm %v", cr.PublicKeyAlgorithm) + } + + vaultReq := map[string]interface{}{ + "csr": string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: cr.Raw, + })), + "format": "pem_bundle", + "ttl": lifetime.Seconds(), + } + + secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq) + if err != nil { + return nil, nil, fmt.Errorf("error signing certificate: %w", err) + } + if secret == nil { + return nil, nil, errors.New("error signing certificate: response is empty") + } + + chain, ok := secret.Data["certificate"].(string) + if !ok { + return nil, nil, errors.New("error unmarshaling vault response: certificate not found") + } + + cert, err := getCertificateBundle(chain) + if err != nil { + return nil, nil, err + } + + // Return certificate and certificate chain + return cert.leaf, cert.intermediates, nil +} + +func loadOptions(config json.RawMessage) (*VaultOptions, error) { + // setup default values + vc := VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "default", + } + + err := json.Unmarshal(config, &vc) + if err != nil { + return nil, fmt.Errorf("error decoding vaultCAS config: %w", err) + } + + if vc.PKIRoleRSA == "" { + vc.PKIRoleRSA = vc.PKIRoleDefault + } + if vc.PKIRoleEC == "" { + vc.PKIRoleEC = vc.PKIRoleDefault + } + if vc.PKIRoleEd25519 == "" { + vc.PKIRoleEd25519 = vc.PKIRoleDefault + } + + return &vc, nil +} + +func parseCertificates(pemCert string) []*x509.Certificate { + var certs []*x509.Certificate + rest := []byte(pemCert) + var block *pem.Block + for { + block, rest = pem.Decode(rest) + if block == nil { + break + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + break + } + certs = append(certs, cert) + } + return certs +} + +func getCertificateBundle(chain string) (*certBundle, error) { + var root *x509.Certificate + var leaf *x509.Certificate + var intermediates []*x509.Certificate + for _, cert := range parseCertificates(chain) { + switch { + case isRoot(cert): + root = cert + case cert.BasicConstraintsValid && cert.IsCA: + intermediates = append(intermediates, cert) + default: + leaf = cert + } + } + + certificate := &certBundle{ + root: root, + leaf: leaf, + intermediates: intermediates, + } + + return certificate, nil +} + +// isRoot returns true if the given certificate is a root certificate. +func isRoot(cert *x509.Certificate) bool { + if cert.BasicConstraintsValid && cert.IsCA { + return cert.CheckSignatureFrom(cert) == nil + } + return false +} + +// formatSerialNumber formats a serial number to a dash-separated hexadecimal +// string. +func formatSerialNumber(sn *big.Int) string { + var ret bytes.Buffer + for _, b := range sn.Bytes() { + if ret.Len() > 0 { + ret.WriteString("-") + } + ret.WriteString(hex.EncodeToString([]byte{b})) + } + return ret.String() +} diff --git a/cas/vaultcas/vaultcas_test.go b/cas/vaultcas/vaultcas_test.go new file mode 100644 index 00000000..0ea0c4b1 --- /dev/null +++ b/cas/vaultcas/vaultcas_test.go @@ -0,0 +1,524 @@ +package vaultcas + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" + "time" + + vault "github.com/hashicorp/vault/api" + "github.com/smallstep/certificates/cas/apiv1" + "go.step.sm/crypto/pemutil" +) + +var ( + testCertificateSigned = `-----BEGIN CERTIFICATE----- +MIIB/DCCAaKgAwIBAgIQHHFuGMz0cClfde5kqP5prTAKBggqhkjOPQQDAjAqMSgw +JgYDVQQDEx9Hb29nbGUgQ0FTIFRlc3QgSW50ZXJtZWRpYXRlIENBMB4XDTIwMDkx +NTAwMDQ0M1oXDTMwMDkxMzAwMDQ0MFowHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0 +ZXAuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMqNCiXMvbn74LsHzRv+8 +17m9vEzH6RHrg3m82e0uEc36+fZWV/zJ9SKuONmnl5VP79LsjL5SVH0RDj73U2XO +DKOBtjCBszAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG +AQUFBwMCMB0GA1UdDgQWBBRTA2cTs7PCNjnps/+T0dS8diqv0DAfBgNVHSMEGDAW +gBRIOVqyLDSlErJLuWWEvRm5UU1r1TBCBgwrBgEEAYKkZMYoQAIEMjAwEwhjbG91 +ZGNhcxMkZDhkMThhNjgtNTI5Ni00YWYzLWFlNGItMmY4NzdkYTNmYmQ5MAoGCCqG +SM49BAMCA0gAMEUCIGxl+pqJ50WYWUqK2l4V1FHoXSi0Nht5kwTxFxnWZu1xAiEA +zemu3bhWLFaGg3s8i+HTEhw4RqkHP74vF7AVYp88bAw= +-----END CERTIFICATE-----` + testCertificateCsrEc = `-----BEGIN CERTIFICATE REQUEST----- +MIHoMIGPAgEAMA0xCzAJBgNVBAMTAkVDMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD +QgAEUVVVZGD6eUrB20T/qrjKZoYzseQ18AIm9jtUNpQn5hIClpdk2zKy5bja3iUa +nmqRKCIz/B/MU55zuNDeckqqX6AgMB4GCSqGSIb3DQEJDjERMA8wDQYDVR0RBAYw +BIICRUMwCgYIKoZIzj0EAwIDSAAwRQIhAJxpWyH7cctbzcnK1JBWDAmc/G61bq9y +otHrQDfYvS8bAiBVGQz2cfO2SqhvkkQbOqWUFjk1wHzISvlTjyc3IJ7FLw== +-----END CERTIFICATE REQUEST-----` + testCertificateCsrRsa = `-----BEGIN CERTIFICATE REQUEST----- +MIICdDCCAVwCAQAwDjEMMAoGA1UEAxMDUlNBMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAxe5XLSZrTCzzH0FJCXvZwghAY5XztzjseSRcm0jL8Q7nvNWi +Vpu1n7EmfVU9b8sbvtVYqMQV+hMdj2C/NIw4Yal4Wg+BgunYOrRqfY7oDm4csG0R +g5v0h2yQw14kqVrftNyojX0Nv/CPboCGl64PA9zsEXQTB3Y1AUWrUGPiBWNACYIH +mjv70Ay9JKBBAqov38I7nka/RgYAl5DCHzU2vvODriBYFWagnzycA4Ni5EKTz93W +SPdDEhkWi3ugUqal3SvgHl8re+8d7ghLn85Y3TFuyU2nSMDPHaymsiNFw1mRwOw3 +lAseidHJkPQs7q6FiYXaeqetf1j/gw0n23ZogwIDAQABoCEwHwYJKoZIhvcNAQkO +MRIwEDAOBgNVHREEBzAFggNSU0EwDQYJKoZIhvcNAQELBQADggEBALnO5vcDkgGO +GQoSINa2NmNFxAtYQGYHok5KXYX+S+etmOmDrmrhsl/pSjN3GPCPlThFlbLStB70 +oJw67nEjGf0hPEBVlm+qFUsYQ1KGRZFAWDSMQ//pU225XFDCmlzHfV7gZjSkP9GN +Gc5VECOzx6hAFR+IEL/l/1GG5HHkPPrr/8OvuIfm2V5ofYmhsXMVVYH52qPofMAV +B8UdNnZK3nyLdUqVd+PYUUJmN4bJ8YfxofKKgbLkhvkKp4OZ9vkwUi2+61NdHTf2 +wIauOyxEoTlJpU6oA/sxu/2Ht2DP+8y6mognLBuKklE/VH3/2iqQWyg1NV5hyg3b +loVSdLsIh5Y= +-----END CERTIFICATE REQUEST-----` + testCertificateCsrEd25519 = `-----BEGIN CERTIFICATE REQUEST----- +MIGuMGICAQAwDjEMMAoGA1UEAxMDT0tQMCowBQYDK2VwAyEAopc6daK4zYR6BDAM +pV/v53oR/ewbtrkHZQkN/amFMLagITAfBgkqhkiG9w0BCQ4xEjAQMA4GA1UdEQQH +MAWCA09LUDAFBgMrZXADQQDJi47MAgl/WKAz+V/kDu1k/zbKk1nrHHAUonbofHUW +M6ihSD43+awq3BPeyPbToeH5orSH9l3MuTfbxPb5BVEH +-----END CERTIFICATE REQUEST-----` + testRootCertificate = `-----BEGIN CERTIFICATE----- +MIIBeDCCAR+gAwIBAgIQcXWWjtSZ/PAyH8D1Ou4L9jAKBggqhkjOPQQDAjAbMRkw +FwYDVQQDExBDbG91ZENBUyBSb290IENBMB4XDTIwMTAyNzIyNTM1NFoXDTMwMTAy +NzIyNTM1NFowGzEZMBcGA1UEAxMQQ2xvdWRDQVMgUm9vdCBDQTBZMBMGByqGSM49 +AgEGCCqGSM49AwEHA0IABIySHA4b78Yu4LuGhZIlv/PhNwXz4ZoV1OUZQ0LrK3vj +B13O12DLZC5uj1z3kxdQzXUttSbtRv49clMpBiTpsZKjRTBDMA4GA1UdDwEB/wQE +AwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSZ+t9RMHbFTl5BatM3 +5bJlHPOu3DAKBggqhkjOPQQDAgNHADBEAiASah6gg0tVM3WI0meCQ4SEKk7Mjhbv ++SmhuZHWV1QlXQIgRXNyWcpVUrAoG6Uy1KQg07LDpF5dFeK9InrDxSJAkVo= +-----END CERTIFICATE-----` + testRootFingerprint = `62e816cbac5c501b7705e18415503852798dfbcd67062f06bcb4af67c290e3c8` +) + +func mustParseCertificate(t *testing.T, pemCert string) *x509.Certificate { + t.Helper() + crt := parseCertificates(pemCert)[0] + return crt +} + +func mustParseCertificateRequest(t *testing.T, pemData string) *x509.CertificateRequest { + t.Helper() + csr, err := pemutil.ParseCertificateRequest([]byte(pemData)) + if err != nil { + t.Fatal(err) + } + return csr +} + +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + writeJSON := func(w http.ResponseWriter, v interface{}) { + _ = json.NewEncoder(w).Encode(v) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "98a4c7ab-b1fe-361b-ba0b-e307aacfd587" + } + }`) + case r.RequestURI == "/v1/pki/sign/ec": + w.WriteHeader(http.StatusOK) + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} + writeJSON(w, cert) + return + case r.RequestURI == "/v1/pki/sign/rsa": + w.WriteHeader(http.StatusOK) + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} + writeJSON(w, cert) + return + case r.RequestURI == "/v1/pki/sign/ed25519": + w.WriteHeader(http.StatusOK) + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} + writeJSON(w, cert) + return + case r.RequestURI == "/v1/pki/cert/ca_chain": + w.WriteHeader(http.StatusOK) + cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} + writeJSON(w, cert) + return + case r.RequestURI == "/v1/pki/revoke": + buf := new(bytes.Buffer) + buf.ReadFrom(r.Body) + m := make(map[string]string) + json.Unmarshal(buf.Bytes(), &m) + switch { + case m["serial_number"] == "1c-71-6e-18-cc-f4-70-29-5f-75-ee-64-a8-fe-69-ad": + w.WriteHeader(http.StatusOK) + return + case m["serial_number"] == "01-e2-40": + w.WriteHeader(http.StatusOK) + return + // both + case m["serial_number"] == "01-34-3e": + w.WriteHeader(http.StatusOK) + return + default: + w.WriteHeader(http.StatusNotFound) + } + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } + + config := vault.DefaultConfig() + config.Address = srv.URL + + client, err := vault.NewClient(config) + if err != nil { + srv.Close() + t.Fatal(err) + } + + return u, client +} + +func TestNew_register(t *testing.T) { + caURL, _ := testCAHelper(t) + + fn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.VaultCAS) + if !ok { + t.Errorf("apiv1.Register() ok = %v, want true", ok) + return + } + _, err := fn(context.Background(), apiv1.Options{ + CertificateAuthority: caURL.String(), + CertificateAuthorityFingerprint: testRootFingerprint, + Config: json.RawMessage(`{ + "AuthType": "approle", + "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false} + }`), + }) + + if err != nil { + t.Errorf("New() error = %v", err) + return + } +} + +func TestVaultCAS_CreateCertificate(t *testing.T) { + _, client := testCAHelper(t) + + options := VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + } + + type fields struct { + client *vault.Client + options VaultOptions + } + + type args struct { + req *apiv1.CreateCertificateRequest + } + + tests := []struct { + name string + fields fields + args args + want *apiv1.CreateCertificateResponse + wantErr bool + }{ + {"ok ec", fields{client, options}, args{&apiv1.CreateCertificateRequest{ + CSR: mustParseCertificateRequest(t, testCertificateCsrEc), + Lifetime: time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: mustParseCertificate(t, testCertificateSigned), + CertificateChain: nil, + }, false}, + {"ok rsa", fields{client, options}, args{&apiv1.CreateCertificateRequest{ + CSR: mustParseCertificateRequest(t, testCertificateCsrRsa), + Lifetime: time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: mustParseCertificate(t, testCertificateSigned), + CertificateChain: nil, + }, false}, + {"ok ed25519", fields{client, options}, args{&apiv1.CreateCertificateRequest{ + CSR: mustParseCertificateRequest(t, testCertificateCsrEd25519), + Lifetime: time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: mustParseCertificate(t, testCertificateSigned), + CertificateChain: nil, + }, false}, + {"fail CSR", fields{client, options}, args{&apiv1.CreateCertificateRequest{ + CSR: nil, + Lifetime: time.Hour, + }}, nil, true}, + {"fail lifetime", fields{client, options}, args{&apiv1.CreateCertificateRequest{ + CSR: mustParseCertificateRequest(t, testCertificateCsrEc), + Lifetime: 0, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &VaultCAS{ + client: tt.fields.client, + config: tt.fields.options, + } + got, err := c.CreateCertificate(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("VaultCAS.CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("VaultCAS.CreateCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVaultCAS_GetCertificateAuthority(t *testing.T) { + caURL, client := testCAHelper(t) + + type fields struct { + client *vault.Client + options VaultOptions + fingerprint string + } + + type args struct { + req *apiv1.GetCertificateAuthorityRequest + } + + options := VaultOptions{ + PKIMountPath: "pki", + } + + rootCert := parseCertificates(testRootCertificate)[0] + + tests := []struct { + name string + fields fields + args args + want *apiv1.GetCertificateAuthorityResponse + wantErr bool + }{ + {"ok", fields{client, options, testRootFingerprint}, args{&apiv1.GetCertificateAuthorityRequest{ + Name: caURL.String(), + }}, &apiv1.GetCertificateAuthorityResponse{ + RootCertificate: rootCert, + }, false}, + {"fail fingerprint", fields{client, options, "fail"}, args{&apiv1.GetCertificateAuthorityRequest{ + Name: caURL.String(), + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &VaultCAS{ + client: tt.fields.client, + fingerprint: tt.fields.fingerprint, + config: tt.fields.options, + } + got, err := s.GetCertificateAuthority(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("VaultCAS.GetCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("VaultCAS.GetCertificateAuthority() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVaultCAS_RevokeCertificate(t *testing.T) { + _, client := testCAHelper(t) + + options := VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + } + + type fields struct { + client *vault.Client + options VaultOptions + } + + type args struct { + req *apiv1.RevokeCertificateRequest + } + + testCrt := parseCertificates(testCertificateSigned)[0] + + tests := []struct { + name string + fields fields + args args + want *apiv1.RevokeCertificateResponse + wantErr bool + }{ + {"ok serial number", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ + SerialNumber: "123456", + Certificate: nil, + }}, &apiv1.RevokeCertificateResponse{}, false}, + {"ok certificate", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ + SerialNumber: "", + Certificate: testCrt, + }}, &apiv1.RevokeCertificateResponse{ + Certificate: testCrt, + }, false}, + {"ok both", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ + SerialNumber: "78910", + Certificate: testCrt, + }}, &apiv1.RevokeCertificateResponse{ + Certificate: testCrt, + }, false}, + {"fail serial string", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ + SerialNumber: "fail", + Certificate: nil, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &VaultCAS{ + client: tt.fields.client, + config: tt.fields.options, + } + got, err := s.RevokeCertificate(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("VaultCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("VaultCAS.RevokeCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVaultCAS_RenewCertificate(t *testing.T) { + _, client := testCAHelper(t) + + options := VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + } + + type fields struct { + client *vault.Client + options VaultOptions + } + + type args struct { + req *apiv1.RenewCertificateRequest + } + + tests := []struct { + name string + fields fields + args args + want *apiv1.RenewCertificateResponse + wantErr bool + }{ + {"not implemented", fields{client, options}, args{&apiv1.RenewCertificateRequest{ + CSR: mustParseCertificateRequest(t, testCertificateCsrEc), + Lifetime: time.Hour, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &VaultCAS{ + client: tt.fields.client, + config: tt.fields.options, + } + got, err := s.RenewCertificate(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("VaultCAS.RenewCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("VaultCAS.RenewCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVaultCAS_loadOptions(t *testing.T) { + tests := []struct { + name string + raw string + want *VaultOptions + wantErr bool + }{ + { + "ok mandatory PKIRole PKIRoleEd25519", + `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`, + &VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "role", + PKIRoleEC: "role", + PKIRoleEd25519: "ed25519", + }, + false, + }, + { + "ok mandatory PKIRole PKIRoleEC", + `{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`, + &VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "role", + PKIRoleEC: "ec", + PKIRoleEd25519: "role", + }, + false, + }, + { + "ok mandatory PKIRole PKIRoleRSA", + `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`, + &VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "role", + PKIRoleEd25519: "role", + }, + false, + }, + { + "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519", + `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, + &VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "default", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + }, + false, + }, + { + "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault", + `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, + &VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := loadOptions(json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("VaultCAS.loadOptions() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("VaultCAS.loadOptions() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/step-ca/main.go b/cmd/step-ca/main.go index 96e7fbd5..bc7bf2e3 100644 --- a/cmd/step-ca/main.go +++ b/cmd/step-ca/main.go @@ -40,6 +40,7 @@ import ( _ "github.com/smallstep/certificates/cas/cloudcas" _ "github.com/smallstep/certificates/cas/softcas" _ "github.com/smallstep/certificates/cas/stepcas" + _ "github.com/smallstep/certificates/cas/vaultcas" ) // commit and buildTime are filled in during build by the Makefile diff --git a/cmd/step-pkcs11-init/main.go b/cmd/step-pkcs11-init/main.go index 4dc15799..7a23c664 100644 --- a/cmd/step-pkcs11-init/main.go +++ b/cmd/step-pkcs11-init/main.go @@ -47,6 +47,7 @@ type Config struct { RootFile string KeyFile string Pin string + PinFile string NoCerts bool EnableSSH bool Force bool @@ -74,6 +75,8 @@ func (c *Config) Validate() error { return errors.New("flag `--root-gen` requires flag `--root-key-obj`") case c.RootFile == "" && c.GenerateRoot && c.RootPath == "": return errors.New("flag `--root-gen` requires `--root-cert-path`") + case c.Pin != "" && c.PinFile != "": + return errors.New("Only set one of pin and pin-file") default: if c.RootFile != "" { c.GenerateRoot = false @@ -108,6 +111,7 @@ func main() { var c Config flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.") flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN") + flag.StringVar(&c.PinFile, "pin-file", "", "PKCS #11 PIN File") // Option 1: Generate new root flag.BoolVar(&c.GenerateRoot, "root-gen", true, "Enable the generation of a root key.") flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.") @@ -147,7 +151,18 @@ func main() { // Initialize windows terminal ui.Init() - if u.Get("pin-value") == "" && u.Get("pin-source") == "" && c.Pin == "" { + switch { + case u.Get("pin-value") != "": + case u.Get("pin-source") != "": + case c.Pin != "": + case c.PinFile != "": + content, err := os.ReadFile(c.PinFile) + if err != nil { + fatal(err) + } + c.Pin = string(content) + + default: pin, err := ui.PromptPassword("What is the PKCS#11 PIN?") if err != nil { fatal(err) diff --git a/commands/onboard.go b/commands/onboard.go index ebd468f5..afecba9d 100644 --- a/commands/onboard.go +++ b/commands/onboard.go @@ -23,7 +23,8 @@ import ( // defaultOnboardingURL is the production onboarding url, to use a development // url use: -// export STEP_CA_ONBOARDING_URL=http://localhost:3002/onboarding/ +// +// export STEP_CA_ONBOARDING_URL=http://localhost:3002/onboarding/ const defaultOnboardingURL = "https://api.smallstep.com/onboarding/" type onboardingConfiguration struct { diff --git a/db/db.go b/db/db.go index ccaf4056..b93b23ca 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "crypto/x509" "encoding/json" "strconv" @@ -8,6 +9,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" "golang.org/x/crypto/ssh" @@ -15,6 +17,7 @@ import ( var ( certsTable = []byte("x509_certs") + certsDataTable = []byte("x509_certs_data") revokedCertsTable = []byte("revoked_x509_certs") crlTable = []byte("x509_crl") revokedSSHCertsTable = []byte("revoked_ssh_certs") @@ -52,14 +55,42 @@ type AuthDB interface { Revoke(rci *RevokedCertificateInfo) error RevokeSSH(rci *RevokedCertificateInfo) error GetCertificate(serialNumber string) (*x509.Certificate, error) - StoreCertificate(crt *x509.Certificate) error UseToken(id, tok string) (bool, error) IsSSHHost(name string) (bool, error) - StoreSSHCertificate(crt *ssh.Certificate) error GetSSHHostPrincipals() ([]string, error) Shutdown() error } +type dbKey struct{} + +// NewContext adds the given authority database to the context. +func NewContext(ctx context.Context, db AuthDB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current authority database from the given context. +func FromContext(ctx context.Context) (db AuthDB, ok bool) { + db, ok = ctx.Value(dbKey{}).(AuthDB) + return +} + +// MustFromContext returns the current database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) AuthDB { + if db, ok := FromContext(ctx); !ok { + panic("authority database is not in the context") + } else { + return db + } +} + +// CertificateStorer is an extension of AuthDB that allows to store +// certificates. +type CertificateStorer interface { + StoreCertificate(crt *x509.Certificate) error + StoreSSHCertificate(crt *ssh.Certificate) error +} + // CertificateRevocationListDB is an interface to indicate whether the DB supports CRL generation type CertificateRevocationListDB interface { GetRevokedCertificates() (*[]RevokedCertificateInfo, error) @@ -93,7 +124,7 @@ func New(c *Config) (AuthDB, error) { tables := [][]byte{ revokedCertsTable, certsTable, usedOTTTable, sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable, - revokedSSHCertsTable, crlTable, + revokedSSHCertsTable, certsDataTable, crlTable, } for _, b := range tables { if err := db.CreateTable(b); err != nil { @@ -291,6 +322,19 @@ func (db *DB) GetCertificate(serialNumber string) (*x509.Certificate, error) { return cert, nil } +// GetCertificateData returns the data stored for a provisioner +func (db *DB) GetCertificateData(serialNumber string) (*CertificateData, error) { + b, err := db.Get(certsDataTable, []byte(serialNumber)) + if err != nil { + return nil, errors.Wrap(err, "database Get error") + } + var data CertificateData + if err := json.Unmarshal(b, &data); err != nil { + return nil, errors.Wrap(err, "error unmarshaling json") + } + return &data, nil +} + // StoreCertificate stores a certificate PEM. func (db *DB) StoreCertificate(crt *x509.Certificate) error { if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil { @@ -299,6 +343,47 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { return nil } +// CertificateData is the JSON representation of the data stored in +// x509_certs_data table. +type CertificateData struct { + Provisioner *ProvisionerData `json:"provisioner,omitempty"` +} + +// ProvisionerData is the JSON representation of the provisioner stored in the +// x509_certs_data table. +type ProvisionerData struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` +} + +// StoreCertificateChain stores the leaf certificate and the provisioner that +// authorized the certificate. +func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { + leaf := chain[0] + serialNumber := []byte(leaf.SerialNumber.String()) + data := &CertificateData{} + if p != nil { + data.Provisioner = &ProvisionerData{ + ID: p.GetID(), + Name: p.GetName(), + Type: p.GetType().String(), + } + } + b, err := json.Marshal(data) + if err != nil { + return errors.Wrap(err, "error marshaling json") + } + // Add certificate and certificate data in one transaction. + tx := new(database.Tx) + tx.Set(certsTable, serialNumber, leaf.Raw) + tx.Set(certsDataTable, serialNumber, b) + if err := db.Update(tx); err != nil { + return errors.Wrap(err, "database Update error") + } + return nil +} + // UseToken returns true if we were able to successfully store the token for // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { @@ -393,6 +478,7 @@ type MockAuthDB struct { MRevoke func(rci *RevokedCertificateInfo) error MRevokeSSH func(rci *RevokedCertificateInfo) error MGetCertificate func(serialNumber string) (*x509.Certificate, error) + MGetCertificateData func(serialNumber string) (*CertificateData, error) MStoreCertificate func(crt *x509.Certificate) error MUseToken func(id, tok string) (bool, error) MIsSSHHost func(principal string) (bool, error) @@ -464,6 +550,17 @@ func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, err return m.Ret1.(*x509.Certificate), m.Err } +// GetCertificateData mock. +func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) { + if m.MGetCertificateData != nil { + return m.MGetCertificateData(serialNumber) + } + if cd, ok := m.Ret1.(*CertificateData); ok { + return cd, m.Err + } + return nil, m.Err +} + // StoreCertificate mock. func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { if m.MStoreCertificate != nil { diff --git a/db/db_test.go b/db/db_test.go index 40f59215..b4515a5b 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,10 +1,15 @@ package db import ( + "crypto/x509" "errors" + "math/big" + "reflect" "testing" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" ) @@ -158,3 +163,133 @@ func TestUseToken(t *testing.T) { }) } } + +func TestDB_StoreCertificateChain(t *testing.T) { + p := &provisioner.JWK{ + ID: "some-id", + Name: "admin", + Type: "JWK", + } + chain := []*x509.Certificate{ + {Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)}, + } + type fields struct { + DB nosql.DB + isUp bool + } + type args struct { + p provisioner.Interface + chain []*x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Fatal("unexpected number of operations") + } + assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[0].Key) + assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) + assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[1].Key) + assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), tx.Operations[1].Value) + return nil + }, + }, true}, args{p, chain}, false}, + {"ok no provisioner", fields{&MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Fatal("unexpected number of operations") + } + assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[0].Key) + assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) + assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[1].Key) + assert.Equals(t, []byte(`{}`), tx.Operations[1].Value) + return nil + }, + }, true}, args{nil, chain}, false}, + {"fail store certificate", fields{&MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + return errors.New("test error") + }, + }, true}, args{p, chain}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &DB{ + DB: tt.fields.DB, + isUp: tt.fields.isUp, + } + if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr { + t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDB_GetCertificateData(t *testing.T) { + type fields struct { + DB nosql.DB + isUp bool + } + type args struct { + serialNumber string + } + tests := []struct { + name string + fields fields + args args + want *CertificateData + wantErr bool + }{ + {"ok", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, []byte("x509_certs_data")) + assert.Equals(t, key, []byte("1234")) + return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil + }, + }, true}, args{"1234"}, &CertificateData{ + Provisioner: &ProvisionerData{ + ID: "some-id", Name: "admin", Type: "JWK", + }, + }, false}, + {"fail not found", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, true}, args{"1234"}, nil, true}, + {"fail db", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("an error") + }, + }, true}, args{"1234"}, nil, true}, + {"fail unmarshal", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return []byte(`{"bad-json"}`), nil + }, + }, true}, args{"1234"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &DB{ + DB: tt.fields.DB, + isUp: tt.fields.isUp, + } + got, err := db.GetCertificateData(tt.args.serialNumber) + if (err != nil) != tt.wantErr { + t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/db/simple.go b/db/simple.go index c34cdb5d..6321e86f 100644 --- a/db/simple.go +++ b/db/simple.go @@ -20,7 +20,7 @@ type SimpleDB struct { usedTokens *sync.Map } -func newSimpleDB(c *Config) (AuthDB, error) { +func newSimpleDB(c *Config) (*SimpleDB, error) { db := &SimpleDB{} db.usedTokens = new(sync.Map) return db, nil diff --git a/docker/Dockerfile.step-ca b/docker/Dockerfile.step-ca index 9363b6ae..46677a91 100644 --- a/docker/Dockerfile.step-ca +++ b/docker/Dockerfile.step-ca @@ -3,15 +3,15 @@ FROM golang:alpine AS builder WORKDIR /src COPY . . -RUN apk add --no-cache \ - curl \ - git \ - make && \ - make V=1 bin/step-ca +RUN apk add --no-cache curl git make +RUN make V=1 bin/step-ca bin/step-awskms-init bin/step-cloudkms-init + FROM smallstep/step-cli:latest COPY --from=builder /src/bin/step-ca /usr/local/bin/step-ca +COPY --from=builder /src/bin/step-awskms-init /usr/local/bin/step-awskms-init +COPY --from=builder /src/bin/step-cloudkms-init /usr/local/bin/step-cloudkms-init USER root RUN apk add --no-cache libcap && setcap CAP_NET_BIND_SERVICE=+eip /usr/local/bin/step-ca diff --git a/docker/Dockerfile.step-ca.hsm b/docker/Dockerfile.step-ca.hsm new file mode 100644 index 00000000..ac59c909 --- /dev/null +++ b/docker/Dockerfile.step-ca.hsm @@ -0,0 +1,34 @@ +FROM golang:alpine AS builder + +WORKDIR /src +COPY . . + +RUN apk add --no-cache curl git make +RUN apk add --no-cache gcc musl-dev pkgconf pcsc-lite-dev +RUN make V=1 GOFLAGS="" build + + +FROM smallstep/step-cli:latest + +COPY --from=builder /src/bin/step-ca /usr/local/bin/step-ca +COPY --from=builder /src/bin/step-awskms-init /usr/local/bin/step-awskms-init +COPY --from=builder /src/bin/step-cloudkms-init /usr/local/bin/step-cloudkms-init +COPY --from=builder /src/bin/step-pkcs11-init /usr/local/bin/step-pkcs11-init +COPY --from=builder /src/bin/step-yubikey-init /usr/local/bin/step-yubikey-init + +USER root +RUN apk add --no-cache libcap && setcap CAP_NET_BIND_SERVICE=+eip /usr/local/bin/step-ca +RUN apk add --no-cache pcsc-lite pcsc-lite-libs +USER step + +ENV CONFIGPATH="/home/step/config/ca.json" +ENV PWDPATH="/home/step/secrets/password" + +VOLUME ["/home/step"] +STOPSIGNAL SIGTERM +HEALTHCHECK CMD step ca health 2>/dev/null | grep "^ok" >/dev/null + +COPY docker/entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/bin/bash", "/entrypoint.sh"] +CMD exec /usr/local/bin/step-ca --password-file $PWDPATH $CONFIGPATH diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 1f48c028..49d6b10c 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -53,6 +53,10 @@ function step_ca_init () { mv $STEPPATH/password $PWDPATH } +if [ -f /usr/sbin/pcscd ]; then + /usr/sbin/pcscd +fi + if [ ! -f "${STEPPATH}/config/ca.json" ]; then init_if_possible fi diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 84e968ab..67c5673d 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -654,7 +654,7 @@ preferably not all - meaning it never leaves the server on which it was created. ### Passwords -When you intialize your PKI (`step ca init`) the root and intermediate +When you initialize your PKI (`step ca init`) the root and intermediate private keys will be encrypted with the same password. We recommend that you change the password with which the intermediate is encrypted at your earliest convenience. @@ -681,7 +681,7 @@ to divide the root private key password across a handful of trusted parties. ### Provisioners -When you intialize your PKI (`step ca init`) a default provisioner will be created +When you initialize your PKI (`step ca init`) a default provisioner will be created and it's private key will be encrypted using the same password used to encrypt the root private key. Before deploying the Step CA you should remove this provisioner and add new ones that are encrypted with new, secure, random passwords. diff --git a/go.mod b/go.mod index 17ea33fe..546ec53d 100644 --- a/go.mod +++ b/go.mod @@ -14,18 +14,27 @@ require ( github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect github.com/Masterminds/sprig/v3 v3.2.2 github.com/ThalesIgnite/crypto11 v1.2.4 - github.com/aws/aws-sdk-go v1.30.29 + github.com/aws/aws-sdk-go v1.37.0 github.com/dgraph-io/ristretto v0.0.4-0.20200906165740-41ebdbffecfd // indirect + github.com/fatih/color v1.9.0 // indirect + github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect github.com/go-chi/chi v4.0.2+incompatible github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 + github.com/go-sql-driver/mysql v1.6.0 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.7 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.1.1 + github.com/hashicorp/vault/api v1.3.1 + github.com/hashicorp/vault/api/auth/approle v0.1.1 + github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 + github.com/jhump/protoreflect v1.9.0 // indirect github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-isatty v0.0.13 // indirect github.com/micromdm/scep/v2 v2.1.0 + github.com/miekg/pkcs11 v1.0.3 // indirect github.com/newrelic/go-agent v2.15.0+incompatible github.com/pkg/errors v0.9.1 github.com/rs/xid v1.2.1 @@ -33,22 +42,29 @@ require ( github.com/slackhq/nebula v1.5.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/nosql v0.4.0 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.7.1 github.com/urfave/cli v1.22.4 + go.etcd.io/bbolt v1.3.6 // indirect go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 - go.step.sm/crypto v0.16.1 - go.step.sm/linkedca v0.11.0 + go.step.sm/crypto v0.16.2 + go.step.sm/linkedca v0.16.1 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 - golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + golang.org/x/net v0.0.0-20220403103023-749bd193bc2b + golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 // indirect + golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect google.golang.org/api v0.70.0 - google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf - google.golang.org/grpc v1.44.0 - google.golang.org/protobuf v1.27.1 + google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de + google.golang.org/grpc v1.45.0 + google.golang.org/protobuf v1.28.0 gopkg.in/square/go-jose.v2 v2.6.0 + gopkg.in/yaml.v3 v3.0.0 // indirect ) // replace github.com/smallstep/nosql => ../nosql // replace go.step.sm/crypto => ../crypto // replace go.step.sm/cli-utils => ../cli-utils // replace go.step.sm/linkedca => ../linkedca + +// use github.com/smallstep/pkcs7 fork with patches applied +replace go.mozilla.org/pkcs7 => github.com/smallstep/pkcs7 v0.0.0-20211016004704-52592125d6f6 diff --git a/go.sum b/go.sum index e7ddd660..32a27e27 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,7 @@ github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUM github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/Masterminds/goutils v1.1.0/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= @@ -132,15 +133,18 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5 github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= +github.com/armon/go-metrics v0.3.9 h1:O2sNqxBdvq8Eq5xmzljcYzAORli6RWCvEym4cJf9m18= +github.com/armon/go-metrics v0.3.9/go.mod h1:4O98XIr/9W0sxpJ8UaYkvjk10Iff7SnFrb4QAOwNTFc= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.30.29 h1:NXNqBS9hjOCpDL8SyCyl38gZX3LLLunKOJc5E7vJ8P0= -github.com/aws/aws-sdk-go v1.30.29/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go v1.37.0 h1:GzFnhOIsrGyQ69s7VgqtrG2BG8v7X7vwB3Xpbd/DBBk= +github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -149,7 +153,10 @@ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kB github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/cenkalti/backoff/v3 v3.0.0 h1:ske+9nBpD9qZsTBoF41nW5L+AIuFBKMeze18XQ3eG1c= +github.com/cenkalti/backoff/v3 v3.0.0/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= @@ -161,6 +168,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5O github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= +github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -188,6 +197,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -224,14 +234,24 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch/v5 v5.5.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= +github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c= +github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= +github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y= +github.com/frankban/quicktest v1.13.0 h1:yNZif1OkDfNoDfb9zZa9aXIpejNR4F23Wely0c+Qdqk= +github.com/frankban/quicktest v1.13.0/go.mod h1:qLE0fzW0VuyUAJgPU19zByoIr0HtCHN/r/VLSOOIySU= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-asn1-ber/asn1-ber v1.3.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-chi/chi v4.0.2+incompatible h1:maB6vn6FqCxrpz4FqWdh4+lwpyZIQS7YEAUcHlgXVRs= github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= @@ -243,6 +263,7 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-kit/kit v0.10.0 h1:dXFJfIHVvUcpSgDOV+Ne6t7jXri8Tfv2uOLHUZ2XNuo= github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-ldap/ldap/v3 v3.1.10/go.mod h1:5Zun81jBTabRaI8lzN7E1JjyEl1g6zI6u9pd8luAK4Q= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= @@ -250,11 +271,14 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG github.com/go-piv/piv-go v1.7.0 h1:rfjdFdASfGV5KLJhSjgpGJ5lzVZVtRWn8ovy/H9HQ/U= github.com/go-piv/piv-go v1.7.0/go.mod h1:ON2WvQncm7dIkCQ7kYJs+nc3V4jHGfrrJnSF8HKy7Gk= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= +github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= @@ -266,8 +290,9 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -297,8 +322,9 @@ github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -347,6 +373,7 @@ github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pf github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU= github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.4.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -361,24 +388,73 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= +github.com/hashicorp/go-hclog v0.16.2 h1:K4ev2ib4LdQETX5cSZBG0DVLk1jwGqSPXBjdah3veNs= +github.com/hashicorp/go-hclog v0.16.2/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= +github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-kms-wrapping/entropy v0.1.0/go.mod h1:d1g9WGtAunDNpek8jUIEJnBlbgKS1N2Q61QkHiZyR1g= github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM= +github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= +github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= +github.com/hashicorp/go-retryablehttp v0.6.6 h1:HJunrbHTDDbBb/ay4kxa1n+dLmttUlnP3V9oNE4hmsM= +github.com/hashicorp/go-retryablehttp v0.6.6/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= +github.com/hashicorp/go-secure-stdlib/mlock v0.1.1 h1:cCRo8gK7oq6A2L6LICkUZ+/a5rLiRXFMf1Qd4xSwxTc= +github.com/hashicorp/go-secure-stdlib/mlock v0.1.1/go.mod h1:zq93CJChV6L9QTfGKtfBxKqD7BqqXx5O04A/ns2p5+I= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1 h1:78ki3QBevHwYrVxnyVeaEz+7WtifHhauYF23es/0KlI= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= +github.com/hashicorp/go-secure-stdlib/password v0.1.1/go.mod h1:9hH302QllNwu1o2TGYtSk8I8kTAN0ca1EHpwhm5Mmzo= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 h1:nd0HIW15E6FG1MsnArYaHfuw9C2zgzM8LxkG5Ty/788= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= +github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.1/go.mod h1:l8slYwnJA26yBz+ErHpp2IRCLr0vuOMGBORIz4rRiAs= github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= +github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.2.0 h1:3vNe/fWF5CBgRIguda1meWhsZHy3m8gCJ5wx+dIzX/E= github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hashicorp/vault/api v1.3.0/go.mod h1:EabNQLI0VWbWoGlA+oBLC8PXmR9D60aUVgQGvangFWQ= +github.com/hashicorp/vault/api v1.3.1 h1:pkDkcgTh47PRjY1NEFeofqR4W/HkNUi9qIakESO2aRM= +github.com/hashicorp/vault/api v1.3.1/go.mod h1:QeJoWxMFt+MsuWcYhmwRLwKEXrjwAFFywzhptMsTIUw= +github.com/hashicorp/vault/api/auth/approle v0.1.1 h1:R5yA+xcNvw1ix6bDuWOaLOq2L4L77zDCVsethNw97xQ= +github.com/hashicorp/vault/api/auth/approle v0.1.1/go.mod h1:mHOLgh//xDx4dpqXoq6tS8Ob0FoCFWLU2ibJ26Lfmag= +github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 h1:6BtyahbF4aQp8gg3ww0A/oIoqzbhpNP1spXU3nHE0n0= +github.com/hashicorp/vault/api/auth/kubernetes v0.1.0/go.mod h1:Pdgk78uIs0mgDOLvc3a+h/vYIT9rznw2sz+ucuH9024= +github.com/hashicorp/vault/sdk v0.3.0 h1:kR3dpxNkhh/wr6ycaJYqp6AFT/i2xaftbfnwZduTKEY= +github.com/hashicorp/vault/sdk v0.3.0/go.mod h1:aZ3fNuL5VNydQk8GcLJ2TV8YCRVvyaakYkhZRoVuhj0= +github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= +github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.3.2 h1:L18LIDzqlW6xN2rEkpdV8+oL/IXWJ1APd+vsdYy4Wdw= @@ -441,14 +517,21 @@ github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74= +github.com/jhump/protoreflect v1.9.0 h1:npqHz788dryJiR/l6K/RUQAyh2SwV91+d1dnh4RjO9w= +github.com/jhump/protoreflect v1.9.0/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= -github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -472,8 +555,9 @@ github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -498,6 +582,7 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.13 h1:qdl+GuBjcsKKDco5BsxPJlId98mSWNKqYA+Co0SC1yA= @@ -510,8 +595,9 @@ github.com/micromdm/scep/v2 v2.1.0 h1:2fS9Rla7qRR266hvUoEauBJ7J6FhgssEiq2OkSKXma github.com/micromdm/scep/v2 v2.1.0/go.mod h1:BkF7TkPPhmgJAMtHfP+sFTKXmgzNJgLQlvvGoOExBcc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4= -github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f h1:eVB9ELsoq5ouItQBr5Tj334bhPJG/MX+m7rTchmzVUQ= github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/pkcs11 v1.0.3 h1:iMwmD7I5225wv84WxIG/bmxz9AXjWvTWIbM/TYHvWtw= +github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= @@ -519,11 +605,17 @@ github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HK github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdIPrefOvVG1VZ96U0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.4.2 h1:6h7AQ0yhTcIsmFmnAwQls75jp2Gzs4iB8W7pjMO+rqo= +github.com/mitchellh/mapstructure v1.4.2/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= @@ -543,7 +635,9 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f/go.mod h1:nwPd6pDNId/Xi16qtKrFHrauSwMNuvk+zcjk89wrnlA= github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU= github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ= +github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso= github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= +github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -560,11 +654,15 @@ github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnh github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= +github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pierrec/lz4 v2.5.2+incompatible h1:WCjObylUIOlKy/+7Abdn34TLIkXiA4UWUMhxq9m9ZXI= +github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -577,6 +675,7 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= @@ -588,6 +687,7 @@ github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6T github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= +github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= @@ -612,6 +712,9 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= @@ -634,6 +737,8 @@ github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/nosql v0.4.0 h1:Go3WYwttUuvwqMtFiiU4g7kBIlY+hR0bIZAqVdakQ3M= github.com/smallstep/nosql v0.4.0/go.mod h1:yKZT5h7cdIVm6wEKM9+jN5dgK80Hljpuy8HNsnI7Gzo= +github.com/smallstep/pkcs7 v0.0.0-20211016004704-52592125d6f6 h1:8Rjy6IZbSM/jcYgBWCoLIGjug7QcoLtF9sUuhDrHD2U= +github.com/smallstep/pkcs7 v0.0.0-20211016004704-52592125d6f6/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -658,17 +763,20 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3 github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= @@ -688,12 +796,10 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= +go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= -go.mozilla.org/pkcs7 v0.0.0-20210730143726-725912489c62/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= -go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 h1:CCriYyAfq1Br1aIYettdHZTy8mBTIPo7We18TuO/bak= -go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -708,14 +814,16 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= -go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk= -go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.11.0 h1:jkG5XDQz9VSz2PH+cGjDvJTwiIziN0SWExTnicWpb8o= -go.step.sm/linkedca v0.11.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= +go.step.sm/crypto v0.16.2 h1:Pr9aazTwWBBZNogUsOqhOrPSdwAa9pPs+lMB602lnDA= +go.step.sm/crypto v0.16.2/go.mod h1:1WkTOTY+fOX/RY4TnZREp6trQAsBHRQ7nu6QJBiNQF8= +go.step.sm/linkedca v0.16.1 h1:CdbMV5SjnlRsgeYTXaaZmQCkYIgJq8BOzpewri57M2k= +go.step.sm/linkedca v0.16.1/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -782,6 +890,7 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20170726083632-f5079bd7f6f7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180530234432-1e491301e022/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -830,8 +939,9 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220403103023-749bd193bc2b h1:vI32FkLJNAWtGD4BwkThwEy6XS7ZLLMHkSkYfF8M0W0= +golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -886,6 +996,7 @@ golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -909,6 +1020,7 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -943,8 +1055,9 @@ golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220209214540-3681064d5158 h1:rm+CHSpPEEW2IsXUib1ThaHIjuBVZjxNgSKmBLFfD4c= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= @@ -964,6 +1077,9 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1008,8 +1124,10 @@ golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWc golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= @@ -1079,6 +1197,7 @@ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20170818010345-ee236bd376b0/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1143,8 +1262,10 @@ google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ6 google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220207164111-0872dc986b00/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI= -google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf h1:SVYXkUz2yZS9FWb2Gm8ivSlbNQzL2Z/NpPKE3RG2jWk= google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI= +google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de h1:9Ti5SG2U4cAcluryUo/sFay3TQKoxiFMfaT0pbizU7k= +google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= +google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= @@ -1176,8 +1297,10 @@ google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= -google.golang.org/grpc v1.44.0 h1:weqSxi/TMs1SqFRMHCtBgXRs8k3X39QIDEZ0pRcttUg= +google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= google.golang.org/grpc v1.44.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= +google.golang.org/grpc v1.45.0 h1:NEpgUqV3Z+ZjkqMsxMg11IaDrXY4RY6CQukSGK0uI1M= +google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -1189,10 +1312,12 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1217,11 +1342,13 @@ gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/kms/azurekms/internal/mock/key_vault_client.go b/kms/azurekms/internal/mock/key_vault_client.go index 42bd55fd..37858854 100644 --- a/kms/azurekms/internal/mock/key_vault_client.go +++ b/kms/azurekms/internal/mock/key_vault_client.go @@ -6,9 +6,10 @@ package mock import ( context "context" + reflect "reflect" + keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" gomock "github.com/golang/mock/gomock" - reflect "reflect" ) // KeyVaultClient is a mock of KeyVaultClient interface diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index 65d06048..2f74f1ad 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -279,7 +279,8 @@ func (k *CloudKMS) createKeyRingIfNeeded(name string) error { // GetPublicKey gets from Google's Cloud KMS a public key by name. Key names // follow the pattern: -// projects/([^/]+)/locations/([a-zA-Z0-9_-]{1,63})/keyRings/([a-zA-Z0-9_-]{1,63})/cryptoKeys/([a-zA-Z0-9_-]{1,63})/cryptoKeyVersions/([a-zA-Z0-9_-]{1,63}) +// +// projects/([^/]+)/locations/([a-zA-Z0-9_-]{1,63})/keyRings/([a-zA-Z0-9_-]{1,63})/cryptoKeys/([a-zA-Z0-9_-]{1,63})/cryptoKeyVersions/([a-zA-Z0-9_-]{1,63}) func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { if req.Name == "" { return nil, errors.New("createKeyRequest 'name' cannot be empty") diff --git a/kms/pkcs11/opensc_test.go b/kms/pkcs11/opensc_test.go index b365e614..365c075c 100644 --- a/kms/pkcs11/opensc_test.go +++ b/kms/pkcs11/opensc_test.go @@ -14,12 +14,15 @@ var softHSM2Once sync.Once // mustPKCS11 configures a *PKCS11 KMS to be used with OpenSC, using for example // a Nitrokey HSM. To initialize these tests we should run: -// sc-hsm-tool --initialize --so-pin 3537363231383830 --pin 123456 -// Or: -// pkcs11-tool --module /usr/local/lib/opensc-pkcs11.so \ -// --init-token --init-pin \ -// --so-pin=3537363231383830 --new-pin=123456 --pin=123456 \ -// --label="pkcs11-test" +// +// sc-hsm-tool --initialize --so-pin 3537363231383830 --pin 123456 +// +// Or: +// +// pkcs11-tool --module /usr/local/lib/opensc-pkcs11.so \ +// --init-token --init-pin \ +// --so-pin=3537363231383830 --new-pin=123456 --pin=123456 \ +// --label="pkcs11-test" func mustPKCS11(t TBTesting) *PKCS11 { t.Helper() testModule = "OpenSC" diff --git a/kms/pkcs11/softhsm2_test.go b/kms/pkcs11/softhsm2_test.go index ed2ff208..6fc0c248 100644 --- a/kms/pkcs11/softhsm2_test.go +++ b/kms/pkcs11/softhsm2_test.go @@ -14,12 +14,14 @@ var softHSM2Once sync.Once // mustPKCS11 configures a *PKCS11 KMS to be used with SoftHSM2. To initialize // these tests, we should run: -// softhsm2-util --init-token --free \ -// --token pkcs11-test --label pkcs11-test \ -// --so-pin password --pin password +// +// softhsm2-util --init-token --free \ +// --token pkcs11-test --label pkcs11-test \ +// --so-pin password --pin password // // To delete we should run: -// softhsm2-util --delete-token --token pkcs11-test +// +// softhsm2-util --delete-token --token pkcs11-test func mustPKCS11(t TBTesting) *PKCS11 { t.Helper() testModule = "SoftHSM2" diff --git a/kms/pkcs11/yubihsm2_test.go b/kms/pkcs11/yubihsm2_test.go index 281aff54..49eb13d1 100644 --- a/kms/pkcs11/yubihsm2_test.go +++ b/kms/pkcs11/yubihsm2_test.go @@ -14,7 +14,8 @@ var yubiHSM2Once sync.Once // mustPKCS11 configures a *PKCS11 KMS to be used with YubiHSM2. To initialize // these tests, we should run: -// yubihsm-connector -d +// +// yubihsm-connector -d func mustPKCS11(t TBTesting) *PKCS11 { t.Helper() testModule = "YubiHSM2" diff --git a/kms/uri/uri_119_test.go b/kms/uri/uri_119_test.go new file mode 100644 index 00000000..af8f9939 --- /dev/null +++ b/kms/uri/uri_119_test.go @@ -0,0 +1,62 @@ +//go:build go1.19 + +package uri + +import ( + "net/url" + "reflect" + "testing" +) + +func TestParse(t *testing.T) { + type args struct { + rawuri string + } + tests := []struct { + name string + args args + want *URI + wantErr bool + }{ + {"ok", args{"yubikey:slot-id=9a"}, &URI{ + URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"}, + Values: url.Values{"slot-id": []string{"9a"}}, + }, false}, + {"ok schema", args{"cloudkms:"}, &URI{ + URL: &url.URL{Scheme: "cloudkms"}, + Values: url.Values{}, + }, false}, + {"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{ + URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"}, + Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}, + }, false}, + {"ok file", args{"file:///tmp/ca.cert"}, &URI{ + URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"}, + Values: url.Values{}, + }, false}, + {"ok file simple", args{"file:/tmp/ca.cert"}, &URI{ + URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert", OmitHost: true}, + Values: url.Values{}, + }, false}, + {"ok file host", args{"file://tmp/ca.cert"}, &URI{ + URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"}, + Values: url.Values{}, + }, false}, + {"fail schema", args{"cloudkms"}, nil, true}, + {"fail parse", args{"yubi%key:slot-id=9a"}, nil, true}, + {"fail scheme", args{"yubikey"}, nil, true}, + {"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Parse(tt.args.rawuri) + if (err != nil) != tt.wantErr { + t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Parse() = %#v, want %#v", got.URL, tt.want.URL) + } + }) + } +} diff --git a/kms/uri/uri_other_test.go b/kms/uri/uri_other_test.go new file mode 100644 index 00000000..dec50f55 --- /dev/null +++ b/kms/uri/uri_other_test.go @@ -0,0 +1,62 @@ +//go:build !go1.19 + +package uri + +import ( + "net/url" + "reflect" + "testing" +) + +func TestParse(t *testing.T) { + type args struct { + rawuri string + } + tests := []struct { + name string + args args + want *URI + wantErr bool + }{ + {"ok", args{"yubikey:slot-id=9a"}, &URI{ + URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"}, + Values: url.Values{"slot-id": []string{"9a"}}, + }, false}, + {"ok schema", args{"cloudkms:"}, &URI{ + URL: &url.URL{Scheme: "cloudkms"}, + Values: url.Values{}, + }, false}, + {"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{ + URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"}, + Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}, + }, false}, + {"ok file", args{"file:///tmp/ca.cert"}, &URI{ + URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"}, + Values: url.Values{}, + }, false}, + {"ok file simple", args{"file:/tmp/ca.cert"}, &URI{ + URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"}, + Values: url.Values{}, + }, false}, + {"ok file host", args{"file://tmp/ca.cert"}, &URI{ + URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"}, + Values: url.Values{}, + }, false}, + {"fail schema", args{"cloudkms"}, nil, true}, + {"fail parse", args{"yubi%key:slot-id=9a"}, nil, true}, + {"fail scheme", args{"yubikey"}, nil, true}, + {"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Parse(tt.args.rawuri) + if (err != nil) != tt.wantErr { + t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Parse() = %#v, want %#v", got.URL, tt.want.URL) + } + }) + } +} diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index 01fbad0f..2ffb5f3d 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -85,59 +85,6 @@ func TestHasScheme(t *testing.T) { } } -func TestParse(t *testing.T) { - type args struct { - rawuri string - } - tests := []struct { - name string - args args - want *URI - wantErr bool - }{ - {"ok", args{"yubikey:slot-id=9a"}, &URI{ - URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a"}, - Values: url.Values{"slot-id": []string{"9a"}}, - }, false}, - {"ok schema", args{"cloudkms:"}, &URI{ - URL: &url.URL{Scheme: "cloudkms"}, - Values: url.Values{}, - }, false}, - {"ok query", args{"yubikey:slot-id=9a;foo=bar?pin=123456&foo=bar"}, &URI{ - URL: &url.URL{Scheme: "yubikey", Opaque: "slot-id=9a;foo=bar", RawQuery: "pin=123456&foo=bar"}, - Values: url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}, - }, false}, - {"ok file", args{"file:///tmp/ca.cert"}, &URI{ - URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"}, - Values: url.Values{}, - }, false}, - {"ok file simple", args{"file:/tmp/ca.cert"}, &URI{ - URL: &url.URL{Scheme: "file", Path: "/tmp/ca.cert"}, - Values: url.Values{}, - }, false}, - {"ok file host", args{"file://tmp/ca.cert"}, &URI{ - URL: &url.URL{Scheme: "file", Host: "tmp", Path: "/ca.cert"}, - Values: url.Values{}, - }, false}, - {"fail schema", args{"cloudkms"}, nil, true}, - {"fail parse", args{"yubi%key:slot-id=9a"}, nil, true}, - {"fail scheme", args{"yubikey"}, nil, true}, - {"fail parse opaque", args{"yubikey:slot-id=%ZZ"}, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := Parse(tt.args.rawuri) - if (err != nil) != tt.wantErr { - t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Parse() = %#v, want %v", got.URL, tt.want) - } - }) - } -} - func TestParseWithScheme(t *testing.T) { type args struct { scheme string diff --git a/logging/clf.go b/logging/clf.go index cee6c982..0e4d9ae9 100644 --- a/logging/clf.go +++ b/logging/clf.go @@ -19,7 +19,9 @@ type CommonLogFormat struct{} // Format implements the logrus.Formatter interface. It returns the given // logrus entry as a CLF line with the following format: -//